AutoMR-pangu / automr /utils.py
haifei
code and checkpoint
1482463
import re
import os
import json
from typing import Any
import re
import regex
from latex2sympy2 import latex2sympy
from word2number import w2n
def extract_math_answer(text: str) -> str:
    r"""Extract the final answer from a math solution.

    The content of the last ``\boxed{...}`` is preferred. Braces are matched
    by depth, so nested groups such as ``\boxed{\frac{1}{2}}`` are returned
    whole instead of being cut at the first ``}``. If no boxed answer exists,
    the text following "answer is", "final answer" or "therefore" is used
    (searched on the lower-cased text, so that result is lower-case). As a
    last resort the stripped input is returned.

    Args:
        text: Full model output or reference solution.

    Returns:
        The extracted answer with surrounding whitespace removed.
    """
    marker = "\\boxed{"
    boxed_contents = []
    cursor = 0
    while True:
        start = text.find(marker, cursor)
        if start == -1:
            break
        body_start = start + len(marker)
        depth = 1
        pos = body_start
        while pos < len(text) and depth > 0:
            if text[pos] == "{":
                depth += 1
            elif text[pos] == "}":
                depth -= 1
            pos += 1
        if depth == 0:
            # text[pos - 1] is the brace that closes \boxed{.
            boxed_contents.append(text[body_start:pos - 1])
            cursor = pos
            continue
        # Unbalanced braces: keep the historical behaviour of stopping at the
        # first closing brace, then keep scanning for later boxed answers.
        first_close = text.find("}", body_start)
        if first_close != -1:
            boxed_contents.append(text[body_start:first_close])
        cursor = body_start
    if boxed_contents:
        return boxed_contents[-1].strip()

    # Try to find the answer after "answer is" or similar phrases.
    answer_patterns = [
        r'answer is[:\s]+([^\n.]+)',
        r'final answer[:\s]+([^\n.]+)',
        r'therefore[,:\s]+([^\n.]+)',
    ]
    lowered = text.lower()
    for pattern in answer_patterns:
        matches = re.findall(pattern, lowered)
        if matches:
            return matches[-1].strip()
    return text.strip()
def extract_multiple_choice_answer(text: str) -> str:
    """Return the last standalone choice letter (A-D) found in ``text``.

    The text is upper-cased before searching, so lower-case letters that
    stand alone as a word are matched as well. When no such letter exists
    the stripped input is returned unchanged.
    """
    letters = re.findall(r'\b([A-D])\b', text.upper())
    return letters[-1] if letters else text.strip()
def normalize_answer(answer: str) -> str:
    """Return a canonical form of ``answer`` for comparison.

    The answer is stripped and lower-cased, ``$`` and backslash characters
    are dropped, and every run of whitespace collapses to a single space.
    """
    cleaned = answer.strip().lower()
    for symbol in ('$', '\\'):
        cleaned = cleaned.replace(symbol, '')
    return ' '.join(cleaned.split())
def check_answer_match(pred: str, ground_truth: str, task_type: str = "math") -> bool:
    """Check whether a predicted answer matches the ground truth.

    Both values are converted with ``str()``. They match when they are equal
    or when one is a substring of the other. An empty string only matches
    another empty string: ``"" in s`` is always true, so without this guard an
    empty prediction would be counted as correct for every question.

    Args:
        pred: Predicted answer.
        ground_truth: Reference answer.
        task_type: Accepted for backward compatibility; currently unused.

    Returns:
        True when the answers match, False otherwise.
    """
    pred = str(pred)
    ground_truth = str(ground_truth)
    if pred == ground_truth:
        return True
    if not pred or not ground_truth:
        return False
    return pred in ground_truth or ground_truth in pred
def ensure_dir(directory: str):
    """Create ``directory`` (including parents) if it does not exist.

    An empty path is treated as "the current directory" and is a no-op, so
    callers can pass ``os.path.dirname("file.json")`` safely.
    ``exist_ok=True`` avoids the race between checking for the directory and
    creating it.
    """
    if not directory:
        return
    os.makedirs(directory, exist_ok=True)
def save_json(data: Any, path: str):
    """Serialize ``data`` as indented JSON to ``path``.

    Missing parent directories are created first. A bare filename (no
    directory component) is written to the current directory; the parent is
    only created when ``path`` actually has one, because the directory part
    of a bare filename is the empty string.
    """
    parent = os.path.dirname(path)
    if parent:
        ensure_dir(parent)
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2)
def load_json(path: str) -> Any:
    """Load and return the JSON document stored at ``path``.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default locale encoding.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)
def _fix_fracs(string):
    r"""Add the missing braces to shorthand ``\frac`` arguments.

    ``\frac12`` becomes ``\frac{1}{2}`` and ``\frac1{72}`` becomes
    ``\frac{1}{72}``; fractions whose first argument is already braced are
    left alone. If a ``\frac`` is followed by fewer than two characters the
    input cannot be repaired and is returned unchanged.

    The length check is an explicit ``if`` rather than an ``assert`` so it
    still runs when Python is started with ``-O``.
    """
    head, *pieces = string.split("\\frac")
    fixed = [head]
    for piece in pieces:
        fixed.append("\\frac")
        if piece.startswith("{"):
            # Numerator is already braced; keep the rest verbatim.
            fixed.append(piece)
            continue
        if len(piece) < 2:
            return string
        numerator, second, rest = piece[0], piece[1], piece[2:]
        if second != "{":
            # \frac12... -> both arguments are single characters.
            fixed.append("{" + numerator + "}{" + second + "}" + rest)
        else:
            # \frac1{72}... -> only the numerator needs braces.
            fixed.append("{" + numerator + "}" + second + rest)
    return "".join(fixed)
def _fix_a_slash_b(string):
    r"""Rewrite a plain ``a/b`` as ``\frac{a}{b}``.

    Only strings with exactly one ``/`` are considered. Each side must be an
    integer literal, unless it contains ``sqrt`` (then it is kept verbatim).
    The rebuilt ``a/b`` must equal the input exactly, which rejects forms
    such as ``01/2`` or `` 1/2``. Anything else is returned unchanged.

    Failures are detected with explicit checks instead of ``assert`` inside a
    bare ``except`` so the behaviour is the same under ``python -O``.
    """
    parts = string.split("/")
    if len(parts) != 2:
        return string
    a, b = parts
    try:
        if "sqrt" not in a:
            a = int(a)
        if "sqrt" not in b:
            b = int(b)
    except ValueError:
        return string
    if string != "{}/{}".format(a, b):
        return string
    return "\\frac{" + str(a) + "}{" + str(b) + "}"
def _fix_sqrt(string):
    r"""Wrap an unbraced ``\sqrt`` argument in braces (``\sqrt2`` -> ``\sqrt{2}``).

    The argument is the run of word characters that directly follows
    ``\sqrt``; an argument that is already braced is left untouched.
    """
    return re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)
def convert_word_number(text: str) -> str:
    """Convert a spelled-out number to digits, best effort.

    ``text`` is passed to ``w2n.word_to_num``; on success the numeric result
    is returned as a string. If the conversion raises, ``text`` is returned
    unchanged.
    """
    try:
        return str(w2n.word_to_num(text))
    except Exception:
        # Deliberately best-effort, but no longer a bare ``except`` so that
        # KeyboardInterrupt / SystemExit are not swallowed.
        return text
# Unit words and other filler tokens (mainly collected from MathQA) that
# strip_string() removes from an answer when they appear as a standalone token
# (bounded by the start/end of the string or a non-word character).
# The strings are matched verbatim at runtime, so odd spellings and
# mis-encoded entries are kept exactly as they are; a few entries occur twice.
unit_texts = [
"east",
"degree",
"mph",
"kmph",
"ft",
"m sqaure",
" m east",
"sq m",
"deg",
"mile",
"q .",
"monkey",
"prime",
"ratio",
"profit of rs",
"rd",
"o",
"gm",
"p . m",
"lb",
"tile",
"per",
"dm",
"lt",
"gain",
"ab",
"way",
"west",
"a .",
"b .",
"c .",
"d .",
"e .",
"f .",
"g .",
"h .",
"t",
"a",
"h",
"no change",
"men",
"soldier",
"pie",
"bc",
"excess",
"st",
"inches",
"noon",
"percent",
"by",
"gal",
"kmh",
"c",
"acre",
"rise",
"a . m",
"th",
"π r 2",
"sq",
"mark",
"l",
"toy",
"coin",
"sq . m",
"gallon",
"° f",
"profit",
"minw",
"yr",
"women",
"feet",
"am",
"pm",
"hr",
"cu cm",
"square",
"v â € ™",
"are",
"rupee",
"rounds",
"cubic",
"cc",
"mtr",
"s",
"ohm",
"number",
"kmph",
"day",
"hour",
"minute",
"min",
"second",
"man",
"woman",
"sec",
"cube",
"mt",
"sq inch",
"mp",
"∏ cm ³",
"hectare",
"more",
"sec",
"unit",
"cu . m",
"cm 2",
"rs .",
"rs",
"kg",
"g",
"month",
"km",
"m",
"cm",
"mm",
"apple",
"liter",
"loss",
"yard",
"pure",
"year",
"increase",
"decrease",
"d",
"less",
"Surface",
"litre",
"pi sq m",
"s .",
"metre",
"meter",
"inch",
]
# Also register a naive plural (entry + "s") for every entry above.
unit_texts.extend([t + "s" for t in unit_texts])
def strip_string(string, skip_unit=False):
    r"""Normalize a (mostly LaTeX) answer string so equivalent answers compare equal.

    The steps run in a fixed order: whitespace and trailing-period cleanup,
    matrix / fraction / relation macro canonicalization, removal of trailing
    ``\text{...}`` units and of the words in ``unit_texts``, removal of degree,
    dollar and percent signs, word-to-number conversion, infinity spelling,
    quote removal, trailing ``.000`` trimming, dropping a short ``k =`` prefix,
    and finally brace fixes for ``\sqrt``, ``\frac`` and ``a/b``.

    Args:
        string: Raw answer; converted with ``str()`` first.
        skip_unit: When true, the ``unit_texts`` removal pass is skipped.

    Returns:
        The normalized string (possibly empty).
    """
    string = str(string).strip()
    # linebreaks
    string = string.replace("\n", "")
    # trailing "."
    string = string.rstrip(".")
    # remove LaTeX negative thin spaces (\!)
    string = string.replace("\\!", "")

    # matrix: map array / bmatrix environments onto pmatrix
    string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string)
    string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string)
    string = string.replace("bmatrix", "pmatrix")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    string = string.replace("\\neq", "\\ne").replace("\\leq", "\\le").replace("\\geq", "\\ge")

    # remove \left and \right, and unescape literal braces
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    string = string.replace("\\{", "{")
    string = string.replace("\\}", "}")

    # Remove a trailing \text{...} unit (miles, dollars, ...) unless that
    # would leave nothing.
    _string = re.sub(r"\\text{.*?}$", "", string).strip()
    if _string != "" and _string != string:
        string = _string

    if not skip_unit:
        # Remove unit words. The whole list is applied twice.
        for _ in range(2):
            for unit_text in unit_texts:
                # The unit must be bounded by the start/end of the string or a
                # non-word character. It is escaped so that entries such as
                # "a ." match a literal period rather than any character.
                _string = re.sub(
                    r"(^|\W)" + re.escape(unit_text) + r"($|\W)", r"\1\2", string
                )
                if _string != "":
                    string = _string

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs and inline-math delimiters
    string = string.replace("\\$", "")
    string = string.replace("$", "")
    string = string.replace("\\(", "").replace("\\)", "")

    # convert word number to digit
    string = convert_word_number(string)

    # replace "\\text{...}" with "..."
    string = re.sub(r"\\text\{(.*?)\}", r"\1", string)
    for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]:
        string = string.replace(key, "")
    string = string.replace("\\emptyset", r"{}")
    string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}")

    # remove percentage (escaped form first, then any bare sign)
    string = string.replace("\\%", "")
    string = string.replace("%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{."
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    # Unwrap a single alphanumeric token enclosed in {}, () or [].
    # The inner text is tested: a string that starts with a bracket can never
    # be alphanumeric as a whole, so testing the full string never matched.
    if (
        len(string) >= 2
        and (string[0], string[-1]) in (("{", "}"), ("(", ")"), ("[", "]"))
        and string[1:-1].isalnum()
    ):
        string = string[1:-1]

    # inf
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    string = string.replace("+\\infty", "\\infty")

    # and
    string = string.replace("and", "")
    string = string.replace("\\mathbf", "")

    # use regex to remove \mbox{...}
    string = re.sub(r"\\mbox{.*?}", "", string)

    # quote (str.replace returns a new string, so the result must be kept)
    string = string.replace("'", "")
    string = string.replace('"', "")

    # i, j: treat a lone "j" as the imaginary unit "i"
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")

    # replace a.000b where b is not a digit or is the end, with ab
    string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0*$", r"\1", string)

    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # get rid of a short left-hand side such as "k = " or "q = "
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    string = _fix_sqrt(string)
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc.
    # Even works with \frac1{72} (but not \frac{72}1).
    string = _fix_fracs(string)

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in
    # case the model output is X/Y
    string = _fix_a_slash_b(string)
    return string
direct_answer_trigger_for_fewshot = ("choice is", "answer is")


def choice_answer_clean(pred: str):
    """Reduce a free-form prediction to a single multiple-choice answer.

    When a trigger phrase ("choice is" / "answer is") occurs more than once
    the output is treated as in-context-learning style and only the chunk
    before the first blank line is kept. The text after the last trigger (or
    the whole text when there is none) is trimmed, then searched for
    standalone letters A-E on its upper-cased form. With a trigger the first
    letter wins, without one the last letter wins. If no letter is found the
    trimmed text itself is returned.
    """
    text = pred.strip("\n")

    # Repeated trigger => few-shot continuation; keep the first chunk only.
    if any(text.count(trigger) > 1 for trigger in direct_answer_trigger_for_fewshot):
        text = text.split("\n\n")[0]

    # Split on the trigger phrases to isolate the answer part.
    pieces = re.split("|".join(direct_answer_trigger_for_fewshot), text)
    found_trigger = len(pieces) > 1
    if found_trigger:
        text = pieces[-1]

    text = text.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")

    letters = re.findall(r"\b(A|B|C|D|E)\b", text.upper())
    if not letters:
        chosen = text.strip().strip(".")
    elif found_trigger:
        chosen = letters[0]
    else:
        chosen = letters[-1]

    # Remove a trailing period / slash once more.
    return chosen.rstrip(".").rstrip("/")
def find_box(pred_str: str):
    """Return the content that follows the last occurrence of "boxed".

    If that tail starts with ``{`` the text up to the matching closing brace
    is returned (nested braces are kept; with no matching brace everything
    after the opening brace is returned). Otherwise the text up to the first
    ``$`` is returned, stripped. An empty tail yields "".
    """
    tail = pred_str.split("boxed")[-1]
    if not tail:
        return ""
    if tail[0] != "{":
        return tail.split("$")[0].strip()

    depth = 1
    for idx, ch in enumerate(tail[1:], 1):
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return tail[1:idx]
    # No matching closing brace: keep everything after the opening brace.
    return tail[1:]
def clean_units(pred_str: str):
    r"""Replace pi with 3.14 and strip a few unit symbols from a numeric string.

    ``\pi`` is first rewritten to ``π``; a ``π`` directly after a digit
    becomes ``*3.14`` and other occurrences not preceded by ``}`` become
    ``3.14``. Afterwards ``%`` becomes ``/100`` and ``$``, ``¥``, ``°C``,
    `` C`` and ``°`` are removed, in that order.
    """
    text = pred_str.replace("\\pi", "π")

    # Applied in this exact order; later rules only see what earlier ones left.
    pi_rules = (
        (r"(?<![\d}])\\?π", "3.14"),  # π not preceded by a digit or }
        (r"(\d)(\\?π)", r"\1*3.14"),  # "3π" -> "3*3.14"
        (r"\{(\\?π)\}", "3.14"),  # "{π}" -> "3.14"
        (r"\*(\\?π)", "*3.14"),  # "3*π" -> "3*3.14"
    )
    for pattern, replacement in pi_rules:
        text = re.sub(pattern, replacement, text)

    symbol_rules = (
        ("%", "/100"),
        ("$", ""),
        ("¥", ""),
        ("°C", ""),
        (" C", ""),
        ("°", ""),
    )
    for symbol, replacement in symbol_rules:
        text = text.replace(symbol, replacement)
    return text
def extract_theoremqa_answer(pred: str, answer_flag: bool = True):
    """Reduce a TheoremQA-style prediction to a comparable answer string.

    Checks are substring tests on the lower-cased prediction, in this order:
    "yes"/"true" gives "True"; "no"/"false" gives "False"; an option tag
    "(a)".."(f)" returns the prediction unchanged. Otherwise a boxed answer
    is unwrapped and, when ``answer_flag`` is true, the text after the last
    "=" is cleaned with ``clean_units`` and evaluated numerically; if that
    fails, the leading number of a "<number> <unit>" string or the last
    number found is returned ("" when there is none).
    """
    lowered = pred.lower()
    if any(word in lowered for word in ("yes", "true")):
        return "True"
    if any(word in lowered for word in ("no", "false")):
        return "False"
    if any(tag in lowered for tag in ("(a)", "(b)", "(c)", "(d)", "(e)", "(f)")):
        return pred

    # Some of the models somehow get used to boxed output from pre-training
    if "boxed" in pred:
        pred = find_box(pred)
    if not answer_flag:
        return pred

    # Extract the numbers out of the string
    pred = pred.split("=")[-1].strip()
    pred = clean_units(pred)
    try:
        # NOTE(review): eval() runs on text derived from model output; this is
        # unsafe if the prediction can come from an untrusted source.
        return str(eval(str(latex2sympy(pred))))
    except Exception:
        pass

    # "<number> <unit>" => keep the number
    if re.match(r"-?[\d\.]+\s\D+$", pred) or re.match(r"-?[\d\.]+\s[^\s]+$", pred):
        return pred.split(" ")[0]

    # desperate search for the last number
    numbers = re.findall(r"-?\d*\.?\d+", pred)
    return numbers[-1] if numbers else ""
def extract_answer(pred_str, data_name, use_last_number=True):
    r"""Extract and normalize the final answer from a model output.

    Dataset-specific shortcuts come first: for "humaneval" the first fenced
    code block after "### Function Body:" is returned ("" if absent), and for
    "mmlu" an output that starts with "(X)" returns X. Otherwise the answer is
    taken from, in order of precedence: the Minerva phrase
    "final answer is $...$. I hope", the last ``\boxed{...}``, the text after
    "he answer is", "final answer is" or "答案是", and finally the last number
    in the text (only when ``use_last_number`` is true). The result is cleaned
    with ``strip_string``.

    Args:
        pred_str: Model output (or reference solution).
        data_name: Dataset name; selects the shortcuts and unit handling.
        use_last_number: Fall back to the last number when nothing else matches.

    Returns:
        The normalized answer string, possibly "".
    """
    dataset = data_name.lower()
    if dataset == "humaneval":
        matches = re.findall(
            r"### Function Body:\s*\n```python\n(.*?)\n```", pred_str, re.DOTALL
        )
        return matches[0] if matches else ""
    if dataset == "mmlu" and len(pred_str) >= 3 and pred_str[0] == "(" and pred_str[2] == ")":
        return pred_str[1]

    pred_str = pred_str.replace("\u043a\u0438", "")
    if "final answer is $" in pred_str and "$. I hope" in pred_str:
        # minerva_math
        tail = pred_str.split("final answer is $", 1)[1]
        pred = tail.split("$. I hope", 1)[0].strip()
    elif "boxed" in pred_str:
        # Same brace matching as the rest of the module.
        pred = find_box(pred_str)
    elif "he answer is" in pred_str:
        pred = pred_str.split("he answer is")[-1].strip()
    elif "final answer is" in pred_str:
        pred = pred_str.split("final answer is")[-1].strip()
    elif "答案是" in pred_str:
        # Handle Chinese few-shot multiple choice problem answer extraction
        pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip()
    elif use_last_number:
        # Raw string: "\d" and "\." are invalid escapes in a normal literal.
        numbers = re.findall(r"-?\d*\.?\d+", pred_str.replace(",", ""))
        pred = numbers[-1] if numbers else ""
    else:
        pred = ""

    # Join multi-line answers, then drop one leading ":" and one trailing
    # "." and "/" (in that order).
    pred = re.sub(r"\n\s*", "", pred)
    if pred.startswith(":"):
        pred = pred[1:]
    if pred.endswith("."):
        pred = pred[:-1]
    if pred.endswith("/"):
        pred = pred[:-1]

    pred = strip_string(pred, skip_unit=data_name in ["carp_en", "minerva"])

    # "(X)..." => "X" for choice datasets; case-insensitive like the check above.
    if dataset in ("gpqa", "mmlu") and len(pred) >= 3 and pred[0] == "(" and pred[2] == ")":
        pred = pred[1]
    return pred
STRIP_EXCEPTIONS = ["carp_en", "minerva"]


def parse_ground_truth(groudtruth_solution: str, data_name):
    """Return the answer extracted from a reference solution.

    Thin wrapper around ``extract_answer`` using its default settings.
    """
    return extract_answer(groudtruth_solution, data_name)
# NOTE(review): fields assumed to hold the reference answer/solution of an
# example -- confirm against the dataset loaders that feed parse_question.
_GROUND_TRUTH_KEYS = ("answer", "Answer", "solution", "target")


def _answer_kind_hint(example, data_name):
    """Return " (True or False)" / " (Yes or No)" for boolean ground truths, else ""."""
    raw = next(
        (
            example[key]
            for key in _GROUND_TRUTH_KEYS
            if key in example and isinstance(example[key], str)
        ),
        None,
    )
    if raw is None:
        return ""
    gt_lower = raw.strip().lower()
    if gt_lower not in ("true", "false", "yes", "no"):
        # Not a bare boolean word: look at the answer parsed from the solution.
        parsed = parse_ground_truth(raw, data_name)
        gt_lower = parsed.lower() if isinstance(parsed, str) else ""
    if gt_lower in ("true", "false"):
        return " (True or False)"
    if gt_lower in ("yes", "no"):
        return " (Yes or No)"
    return ""


def parse_question(example, data_name):
    """Build the question text for ``example`` according to ``data_name``.

    Each known dataset has its own field layout; unknown datasets use the
    first of the keys "question", "problem", "Question", "input" that is
    present. When the example's ground truth is a true/false or yes/no answer
    a matching hint is appended.

    Unlike the earlier implementation this neither mutates
    ``example["choices"]`` nor calls ``parse_ground_truth`` with the whole
    example (which only accepts a solution string and returns one value).

    Args:
        example: Mapping with the dataset's fields.
        data_name: Dataset name.

    Returns:
        The question string, stripped.

    Raises:
        AssertionError: if "mmlu_stem" does not have four choices or the
            "sat_math" options do not start with "A".
    """
    question = ""
    if data_name == "asdiv":
        question = f"{example['body'].strip()} {example['question'].strip()}"
    elif data_name == "svamp":
        body = example["Body"].strip()
        if not body.endswith("."):
            body = body + "."
        question = f'{body} {example["Question"].strip()}'
    elif data_name == "tabmwp":
        title_str = f'regarding "{example["table_title"]}" ' if example["table_title"] else ""
        question = f"Read the following table {title_str}and answer a question:\n"
        question += f'{example["table"]}\n{example["question"]}'
        if example["choices"]:
            question += f' Please select from the following options: {example["choices"]}'
    elif data_name == "carp_en":
        question = example["content"]
    elif data_name == "mmlu_stem":
        choices = example["choices"]
        assert len(choices) == 4
        # Build a new list so the caller's example is left untouched.
        options = " ".join(
            f"({label}) {str(option).strip()}" for label, option in zip("ABCD", choices)
        )
        question = f"{example['question'].strip()}\nAnswer Choices: {options}"
    elif data_name == "sat_math":
        options = example["options"].strip()
        assert "A" == options[0]
        options = "(" + options
        for ch in "BCD":
            # Literal replacement of " B) " by " (B) "; no regex needed.
            options = options.replace(f" {ch}) ", f" ({ch}) ")
        question = f"{example['question'].strip()}\nAnswer Choices: {options}"
    elif "aqua" in data_name:
        options = example["options"]
        choice = "(" + "(".join(options)
        choice = choice.replace("(", " (").replace(")", ") ").strip()
        choice = "\nAnswer Choices: " + choice
        question = example["question"].strip() + choice
    elif data_name == "gaokao_math_qa":
        options_dict = example["options"]
        options = " ".join(f"({key}) {options_dict[key]}" for key in options_dict)
        question = f"{example['question'].strip()}\n选项: {options}"
    else:
        for key in ["question", "problem", "Question", "input"]:
            if key in example:
                question = example[key]
                break

    # Yes/No and True/False questions get an explicit hint.
    question += _answer_kind_hint(example, data_name)
    return question.strip()