| |
| |
| |
| import re |
| from typing import Dict, List |
|
|
| import numpy as np |
|
|
|
|
def parse_math_answer(raw_string):
    r"""Extract the final-answer substring from a free-form math solution.

    The strategies are tried in this order:

    1. If ``raw_string`` contains ``\boxed``, the contents of the last
       ``\boxed{...}`` group are used. When no well-formed group is found
       the result is ``None``; the other strategies are NOT tried.
    2. Otherwise the last match of the greedy pattern ``$...$`` is used.
    3. If that yields nothing, the text after the last ``=`` is used, or,
       when there is no ``=``, the last number found in the text.

    For strategies 1 and 2, if the extracted text contains ``=``, only the
    part after the last ``=`` is kept.

    Args:
        raw_string: The text to search.

    Returns:
        The extracted answer string, or ``None`` if nothing was found.
    """

    def remove_boxed(s):
        # Strip the surrounding "\boxed{" ... "}" wrapper. Explicit checks
        # (instead of assert) keep the behaviour identical under ``python -O``.
        prefix = "\\boxed{"
        if s is None or not s.startswith(prefix) or not s.endswith("}"):
            return None
        answer = s[len(prefix):-1]
        if "=" in answer:
            answer = answer.split("=")[-1].lstrip(" ")
        return answer

    def last_boxed_only_string(string):
        # Locate the last "\boxed" (or "\fbox") and return it together with
        # its brace-balanced argument, or None when the braces never balance.
        idx = string.rfind("\\boxed")
        if idx < 0:
            idx = string.rfind("\\fbox")
            if idx < 0:
                return None
        depth = 0
        for i in range(idx, len(string)):
            ch = string[i]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return string[idx:i + 1]
        return None

    def get_answer_with_dollar_sign(s):
        # Raw string: "\$" is a regex escape, not a Python string escape.
        matches = re.findall(r"\$(.*)\$", s)
        if not matches:
            return None
        last_match = matches[-1]
        if "=" in last_match:
            last_match = last_match.split("=")[-1].lstrip(" ")
        return last_match

    def get_answer_without_dollar_sign(s):
        if "=" in s:
            last_match = s.split("=")[-1].lstrip(" ").rstrip(".")
            # NOTE: "\\n" is a literal backslash followed by "n" (two
            # characters), not a newline character.
            if "\\n" in last_match:
                last_match = last_match.split("\\n")[0]
            return last_match
        # Optional leading "$", then an integer or decimal number that is not
        # immediately followed by a word character.
        matches = re.findall(r"(?:\$)?\d+(?:\.\d+)?(?![\w\d])", s)
        return matches[-1] if matches else None

    if "\\boxed" in raw_string:
        return remove_boxed(last_boxed_only_string(raw_string))

    answer = get_answer_with_dollar_sign(raw_string)
    if not answer:
        answer = get_answer_without_dollar_sign(raw_string)
    return answer
|
|
|
|
| |
| def _fix_fracs(string): |
| substrs = string.split("\\frac") |
| new_str = substrs[0] |
| if len(substrs) > 1: |
| substrs = substrs[1:] |
| for substr in substrs: |
| new_str += "\\frac" |
| if substr[0] == "{": |
| new_str += substr |
| else: |
| try: |
| assert len(substr) >= 2 |
| except Exception: |
| return string |
| a = substr[0] |
| b = substr[1] |
| if b != "{": |
| if len(substr) > 2: |
| post_substr = substr[2:] |
| new_str += "{" + a + "}{" + b + "}" + post_substr |
| else: |
| new_str += "{" + a + "}{" + b + "}" |
| else: |
| if len(substr) > 2: |
| post_substr = substr[2:] |
| new_str += "{" + a + "}" + b + post_substr |
| else: |
| new_str += "{" + a + "}" + b |
| string = new_str |
| return string |
|
|
|
|
| def _fix_a_slash_b(string): |
| if len(string.split("/")) != 2: |
| return string |
| a = string.split("/")[0] |
| b = string.split("/")[1] |
| try: |
| a = int(a) |
| b = int(b) |
| assert string == "{}/{}".format(a, b) |
| new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" |
| return new_string |
| except Exception: |
| return string |
|
|
|
|
| def _remove_right_units(string): |
| |
| if "\\text{ " in string: |
| splits = string.split("\\text{ ") |
| assert len(splits) == 2 |
| return splits[0] |
| else: |
| return string |
|
|
|
|
| def _fix_sqrt(string): |
| if "\\sqrt" not in string: |
| return string |
| splits = string.split("\\sqrt") |
| new_string = splits[0] |
| for split in splits[1:]: |
| if split[0] != "{": |
| a = split[0] |
| new_substr = "\\sqrt{" + a + "}" + split[1:] |
| else: |
| new_substr = "\\sqrt" + split |
| new_string += new_substr |
| return new_string |
|
|
|
|
def _strip_string(string):
    """Normalize a LaTeX-like answer so equivalent spellings compare equal.

    The steps run in a fixed order: remove newlines and negative thin
    spaces, collapse doubled backslashes, unify fraction macros, drop
    left/right sizing commands, degree markers, escaped dollar and percent
    signs and a trailing text unit, add missing leading zeros to decimals,
    drop a short ``k=`` style prefix, brace square-root and fraction
    arguments, remove spaces, and finally rewrite ``0.5`` and plain integer
    ratios ``a/b`` as fractions.

    Args:
        string: The answer text to normalize.

    Returns:
        The normalized text. An input that becomes empty after the removal
        steps is returned as the empty string.
    """
    # Linebreaks.
    string = string.replace("\n", "")

    # Negative thin space "\!".
    string = string.replace("\\!", "")

    # Doubled backslash -> single backslash.
    string = string.replace("\\\\", "\\")

    # tfrac / dfrac -> frac.
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # "\left" and "\right" sizing commands.
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Degree markers.
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # Escaped dollar signs.
    string = string.replace("\\$", "")

    # Trailing "\text{ ...}" unit.
    string = _remove_right_units(string)

    # Escaped percent signs. The original repeated this call with the
    # literal "\%", an invalid escape sequence that denotes the same two
    # characters as "\\%"; the redundant call is dropped.
    string = string.replace("\\%", "")

    # " ." -> " 0." and "{." -> "{0." so decimals get a leading zero.
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    # Nothing left to normalize; also guards the string[0] access below.
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # Drop a short left-hand side such as "k=" (at most two characters),
    # but only when there is exactly one "=".
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    # "\sqrt3" -> "\sqrt{3}".
    string = _fix_sqrt(string)

    # Remove all spaces.
    string = string.replace(" ", "")

    # "\frac12" -> "\frac{1}{2}".
    string = _fix_fracs(string)

    # Special-case the decimal 0.5 so it matches the fraction form.
    if string == "0.5":
        string = "\\frac{1}{2}"

    # "a/b" with integers -> "\frac{a}{b}".
    string = _fix_a_slash_b(string)

    return string
|
|
|
|
def is_equiv(str1, str2, verbose=False):
    """Decide whether two answer texts denote the same math answer.

    Both inputs are passed through ``parse_math_answer`` and then
    ``_strip_string``; the normalized results are compared for equality.
    If normalization raises, the parsed (un-normalized) values are compared
    instead.

    Args:
        str1: First answer text, or ``None``.
        str2: Second answer text, or ``None``.
        verbose: When true, print both normalized strings before comparing.

    Returns:
        ``True`` if the answers are considered equivalent. Two ``None``
        inputs count as equivalent (a warning is printed); exactly one
        ``None`` input does not.
    """
    if str1 is None or str2 is None:
        if str1 is None and str2 is None:
            print("WARNING: Both None")
            return True
        return False

    str1 = parse_math_answer(str1)
    str2 = parse_math_answer(str2)

    try:
        normalized1 = _strip_string(str1)
        normalized2 = _strip_string(str2)
        if verbose:
            print(normalized1, normalized2)
        return normalized1 == normalized2
    except Exception:
        # Best-effort fallback: compare the parsed values as they are.
        return str1 == str2
|
|
|
|
def process_results(doc: dict, results: List[str]) -> Dict[str, int]:
    """Score one generated answer against the document's reference answer.

    Args:
        doc: Mapping that provides the reference under the ``"answer"`` key.
        results: Generated outputs; only the first entry is scored.

    Returns:
        ``{"acc": 1}`` when ``is_equiv`` judges the first result equivalent
        to the reference, otherwise ``{"acc": 0}``.
    """
    candidate = results[0]
    gold = doc["answer"]

    # Print the record when the reference answer is empty or missing.
    if not gold:
        print(doc, candidate, gold)

    return {
        "acc": 1 if is_equiv(candidate, gold) else 0,
    }
|
|
|
|
| |
def process_results_mcqa(doc, results):
    """Compute raw and length-normalized accuracy for a multiple-choice item.

    Args:
        doc: Mapping with ``"gold"`` (a container of correct choice indices)
            and ``"choices"`` (the answer option strings).
        results: One entry per choice; the first element of each entry is
            that choice's score (presumably a log-likelihood; confirm
            against the caller).

    Returns:
        A dict with ``"acc"`` (1.0 if the highest-scoring choice index is in
        ``doc["gold"]``) and ``"acc_norm"`` (the same check after dividing
        each score by the character length of its choice), each 1.0 or 0.0.
    """
    scores = [entry[0] for entry in results]
    gold_indices = doc["gold"]

    top_raw = int(np.argmax(scores))
    acc = 1.0 if top_raw in gold_indices else 0.0

    choice_lengths = np.array([float(len(choice)) for choice in doc["choices"]])
    top_normalized = int(np.argmax(scores / choice_lengths))
    acc_norm = 1.0 if top_normalized in gold_indices else 0.0

    return {"acc": acc, "acc_norm": acc_norm}
|
|