|
|
import regex |
|
|
from typing import Optional |
|
|
from loguru import logger |
|
|
import re |
|
|
from fractions import Fraction |
|
|
from matharena.utils import latex2sympy_fixed |
|
|
from sympy import N, Integer |
|
|
import sympy |
|
|
from typing import Any |
|
|
from enum import Enum |
|
|
from functools import total_ordering |
|
|
from matharena.parse_manual import manual_mapper |
|
|
|
|
|
@total_ordering
class WarningType(Enum):
    """Severity attached to an extracted/parsed answer, from NONE (lowest) to MAJOR.

    Members are ordered by their integer ``value`` so that ``max()`` can be used
    to keep the most severe of several warnings. Comparing against a
    non-member compares ``value`` with the other operand directly.
    """

    NONE = 0
    MINOR = 1
    POSSIBLE = 2
    MAJOR = 3

    def __lt__(self, other):
        rhs = other.value if isinstance(other, type(self)) else other
        return self.value < rhs
|
|
|
|
|
|
|
|
def parse_grading(text: str):
    """Parse a grading report into per-category scores and a total.

    The report is a sequence of blocks of the form::

        Category: <title>
        Points awarded: <integer>
        Description: <text>

    Returns:
        ``{"points": <sum of all points>, "details": [{"title", "points", "desc"}, ...]}``
        with one detail entry per block, in order of appearance.

    Raises:
        ValueError: if a "Points awarded" value is not an integer literal.
    """
    block_re = re.compile(r"""Category:\s*(.*?)(?:\s|\n)*Points\s+awarded:\s*(.*?)(?:\s|\n)*Description:\s*(.*?)(?:\s|\n)*(?=Category:|$)""", re.DOTALL | re.VERBOSE)

    details = [
        {"title": title, "points": int(points), "desc": desc}
        for title, points, desc in block_re.findall(text)
    ]
    return {"points": sum(entry["points"] for entry in details), "details": details}
|
|
|
|
|
def find_last_boxed_content(text: str, list_answer: bool = False) -> "tuple[Optional[str], WarningType]":
    """Return the contents of the last ``\\boxed{...}`` / ``\\fbox{...}`` in *text*.

    Nested braces inside the box are handled by the recursive ``(?2)`` group,
    which requires the third-party ``regex`` module.

    Args:
        text: Response text to search.
        list_answer: When True and the text contains more than one box, scan
            lines from the bottom and return the comma-joined contents of every
            box on the last line that has any box, flagged
            ``WarningType.POSSIBLE``. If no single line holds a complete box
            (e.g. every box spans several lines), fall back to the last box.

    Returns:
        ``(content, warning)``; ``content`` is None (with ``WarningType.NONE``)
        when no box is present.
    """
    pattern = r"(boxed|fbox)\{((?:[^{}]|\{(?2)\})*)\}"
    matches = list(regex.finditer(pattern, text))
    if not matches:
        return None, WarningType.NONE

    if len(matches) > 1 and list_answer:
        # Several boxes on one line (e.g. "\boxed{1}, \boxed{2}") form one list answer.
        for line in reversed(text.split("\n")):
            line_matches = list(regex.finditer(pattern, line))
            if line_matches:
                joined = ",".join(match.group(2) for match in line_matches)
                return joined, WarningType.POSSIBLE

    return matches[-1].group(2), WarningType.NONE
|
|
|
|
|
|
|
|
def extract_boxed_answer(text: str, list_answer: bool = False) -> "tuple[Optional[str], WarningType]":
    """Return the last boxed answer in *text*, keeping only the part after the last ``=``.

    Args:
        text: Response text to search.
        list_answer: Forwarded to ``find_last_boxed_content``.

    Returns:
        ``(answer, warning)`` as produced by ``find_last_boxed_content``, where
        an answer such as ``"x = 5"`` is reduced to ``" 5"``. ``answer`` is None
        when no box was found.
    """
    answer, warning = find_last_boxed_content(text, list_answer)
    if answer is not None and "=" in answer:
        # Keep only the right-hand side of the final equation in the box.
        answer = answer.split("=")[-1]
    return answer, warning
|
|
|
|
|
|
|
|
def extract_boxed_answer_parse(text: str, parse: bool = True, list_answer: bool = False) -> "tuple[Any, WarningType]":
    """Extract the last boxed answer from *text* and convert it to a value.

    Args:
        text: Response text to search.
        parse: When the boxed text is not an integer literal, run it through
            ``parse_answer`` (True) or return the raw string (False).
        list_answer: Forwarded to ``extract_boxed_answer``.

    Returns:
        ``(value, warning)``:
        * ``sympy.Integer`` when the boxed text is an integer literal;
        * otherwise the ``parse_answer`` result (if *parse*) or the raw string;
        * ``(None, WarningType.MAJOR)`` when there is no boxed answer.
        The warning is the most severe of the extraction and parsing warnings.
    """
    answer, warning = extract_boxed_answer(text, list_answer)
    if answer is None:
        return None, WarningType.MAJOR

    try:
        return sympy.Integer(int(answer)), warning
    except ValueError:
        # Not a plain integer literal; fall through to the general parser.
        pass

    if parse:
        parsed_answer, parse_warning = parse_answer(answer)
        # Keep an extraction warning (e.g. POSSIBLE for joined boxes) instead of overwriting it.
        return parsed_answer, max(warning, parse_warning)
    return answer, warning
|
|
|
|
|
def extract_last_integer(text: str) -> "tuple[Optional[int], WarningType]":
    """Return the last standalone run of digits in *text* as an ``int``.

    This is a lenient fallback, so the result is always flagged
    ``WarningType.MAJOR``.

    Returns:
        ``(value, WarningType.MAJOR)``; ``value`` is None when no integer is
        found or it cannot be converted.
    """
    matches = list(regex.finditer(r"\b\d+\b", text))
    if not matches:
        return None, WarningType.MAJOR
    try:
        return int(matches[-1].group()), WarningType.MAJOR
    except ValueError:
        # e.g. int() rejects digit strings longer than the interpreter's
        # int-string conversion limit (Python 3.11+).
        return None, WarningType.MAJOR
|
|
|
|
|
|
|
|
def extract_answer(text: str, strict_parsing: bool = True, parse: bool = True, list_answer: bool = False):
    """Extract the final answer from a model response.

    Unicode box/math glyphs are first mapped to LaTeX, then the last boxed
    answer is extracted (and parsed if *parse*).

    Args:
        text: Raw response text.
        strict_parsing: When False and no boxed answer is found, fall back to
            the last bare integer in the text.
        parse: Forwarded to ``extract_boxed_answer_parse``.
        list_answer: Forwarded to ``extract_boxed_answer_parse``.

    Returns:
        ``(answer, warning)``.
    """
    cleaned_text, unicode_warning = replace_unicode(text)
    answer, boxed_warning = extract_boxed_answer_parse(cleaned_text, parse, list_answer)

    if answer is None and not strict_parsing:
        # Lenient mode: last-resort integer scan (always flagged MAJOR).
        return extract_last_integer(cleaned_text)

    return answer, max(unicode_warning, boxed_warning)
|
|
|
|
|
def parse_answer(s: str, primitive_type: type = None):
    """Parse an answer string into a number, sympy expression or AnswerList.

    Strings listed in ``manual_mapper`` are first replaced by their manual
    translation (flagged ``WarningType.MAJOR``). The text is then cleaned,
    stripped of enclosing parentheses and parsed as a list. A single-element
    result is unwrapped; a multi-element result becomes an ``AnswerList``.

    Args:
        s: Answer text (typically the contents of a box).
        primitive_type: Forwarded to the primitive parser (e.g. ``Fraction``).

    Returns:
        ``(value, warning)``; ``value`` is None (with at least MAJOR) when
        parsing fails.
    """
    warning = WarningType.NONE
    if s in manual_mapper:
        logger.warning(f"Applying manual parsing to {s}")
        s = manual_mapper[s]
        warning = WarningType.MAJOR

    s = remove_outer_brackets(normalize_string(remove_invalid_characters(s)))

    parsed, list_warning = ParseList.parse("(" + s + ")", primitive_type=primitive_type)
    warning = max(warning, list_warning)

    if parsed is None:
        logger.warning(f"Could not parse {s}, returning None")
        return None, max(warning, WarningType.MAJOR)

    if len(parsed) == 1:
        parsed = parsed[0]
    if isinstance(parsed, (list, tuple)):
        parsed = AnswerList(parsed)
    return parsed, warning
|
|
|
|
|
def normalize_string(s):
    r"""Canonicalise LaTeX answer text before list parsing.

    In order:
    * remove the delimiter-sizing commands \left, \right, \Bigl, \Bigr, \bigl
      and \bigr (plain substring removal, so e.g. \leftarrow also loses its
      \left prefix);
    * flatten align environments via ``remove_aligns``;
    * turn [ ] and \{ \} into ( );
    * drop $, \hline and \vline;
    * finish with ``strip``.
    """
    for sizing_command in (r"\left", r"\right", r"\Bigl", r"\Bigr", r"\bigl", r"\bigr"):
        s = s.replace(sizing_command, "")

    s = remove_aligns(s)

    substitutions = (
        ("[", "("),
        ("]", ")"),
        ("\\{", "("),
        ("\\}", ")"),
        ("$", ""),
        (r"\hline", ""),
        (r"\vline", ""),
    )
    for old, new in substitutions:
        s = s.replace(old, new)

    return strip(s)
|
|
|
|
|
def remove_outer_brackets(s):
    """
    Repeatedly remove a pair of parentheses that wraps the entire string.

    The leading "(" and trailing ")" are removed only when that "(" is closed
    by the final character, e.g. "((1, 2))" -> "1, 2". Strings such as
    "(1)(2)" and unbalanced ones such as "((1)" are returned unchanged.

    Parameters:
        s (str): The input string potentially wrapped with brackets.

    Returns:
        str: The string with all fully-enclosing outer parentheses removed.
    """
    while s[:1] == "(" and s[-1:] == ")":
        depth = 0
        for i, char in enumerate(s):
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
                if depth == 0 and i != len(s) - 1:
                    # The first "(" closes before the end: it does not wrap everything.
                    return s
        if depth != 0:
            # Unbalanced: the first "(" never closes, so it cannot be an outer pair.
            return s
        s = s[1:-1]
    return s
|
|
|
|
|
def remove_aligns(s: str) -> str:
    """Unwrap every ``\\begin{align...}...\\end{align...}`` environment in *s*.

    Any environment whose name starts with ``align`` (e.g. ``align*``) is
    replaced by its body with alignment markers (``&``) and LaTeX line breaks
    (two backslashes) removed. Matching spans newlines.
    """
    def _flatten(match):
        body = match.group(1)
        body = body.replace('&', '')
        return body.replace("\\\\", "")

    return re.sub(r'\\begin{align[^}]*}(.*?)\\end{align[^}]*}', _flatten, s, flags=re.DOTALL)
|
|
|
|
|
def replace_unicode(text: str) -> "tuple[str, WarningType]":
    """Map a few Unicode glyphs onto LaTeX equivalents.

    Several bracket/box-drawing characters are treated as box delimiters and
    rewritten to ``\\boxed{`` / ``}``. The square-root sign, multiplication
    sign and narrow no-break space become ``\\sqrt``, ``\\cdot`` and a space.

    Returns:
        ``(text, warning)`` where ``warning`` is ``WarningType.POSSIBLE`` if any
        box-delimiter replacement changed the text and ``WarningType.NONE``
        otherwise.
    """
    original = text

    # Order matters: the "\n"-prefixed variants must be handled before the bare glyph.
    box_delimiters = (
        ("\u23a7", r"\boxed{"),
        ("\u23ab", r"}"),
        ("\n\u2502", r"\boxed{"),
        ("\u2502", r"}"),
        ("\n\u2503", r"\boxed{"),
        ("\u2503", r"}"),
        ("\n\uf8f0", r"\boxed{"),
        ("\uf8fb", r"}"),
    )
    for glyph, latex in box_delimiters:
        text = text.replace(glyph, latex)

    warning = WarningType.NONE if text == original else WarningType.POSSIBLE

    # These do not affect the warning.
    for glyph, latex in (("\u221a", r"\sqrt"), ("\u00d7", r"\cdot"), ("\u202f", r" ")):
        text = text.replace(glyph, latex)

    return text, warning
|
|
|
|
|
def remove_invalid_characters(text):
    """Strip the LaTeX spacing commands ``\\;`` and ``\\,`` from *text*."""
    return text.replace("\\;", "").replace("\\,", "")
|
|
|
|
|
def strip(s: str):
    """Trim whitespace and leftover LaTeX line-break debris from *s*.

    Steps, in order:
    1. ``str.strip()`` the surrounding whitespace (once).
    2. Drop every leading and trailing two-character sequence
       backslash + "n" (a literal backslash-n, not a newline character).
    3. Drop every leading backslash + space pair.
    4. While *s* starts with two or more backslashes, an optional newline
       character and then "(", drop its first three characters.

    Returns:
        The cleaned string.
    """
    s = s.strip()

    # r"\n" is the two characters backslash + "n", not a newline.
    while s.startswith(r"\n"):
        s = s[2:]
    while s.endswith(r"\n"):
        s = s[:-2]
    while s.startswith("\\ "):
        s = s[2:]

    # NOTE(review): this always removes exactly 3 characters, regardless of how many
    # the regex matched. For two backslashes + newline + "(" it removes the backslashes
    # and newline. For two backslashes immediately followed by "(" it also removes
    # the "(" itself. Confirm this is intended.
    while re.match(r"\\{2,}\n?\(", s):
        s = s[3:]
    return s
|
|
|
|
|
def check_answers(ans1, ans2):
    """Return True if two parsed answers should be graded as equal.

    Rules:
    * ``None`` on either side never matches.
    * A sequence (list, tuple or AnswerList) never matches a scalar.
    * If either side is an ``AnswerList``, ``AnswerList.equals`` is used
      (order-insensitive).
    * Two plain lists/tuples (e.g. nested tuples produced by the list parser)
      are compared element by element, in order.
    * Scalars that both provide a callable ``equals`` (e.g. sympy objects) use
      it. Otherwise strings compare with ``==`` and numbers must differ by
      less than 1e-10.

    Any exception raised during comparison is logged and treated as "not equal".
    """
    if ans1 is None or ans2 is None:
        return False

    sequence_types = (list, tuple, AnswerList)
    ans1_is_seq = isinstance(ans1, sequence_types)
    ans2_is_seq = isinstance(ans2, sequence_types)
    if ans1_is_seq != ans2_is_seq:
        return False

    try:
        if ans1_is_seq:
            if isinstance(ans1, AnswerList):
                return ans1.equals(ans2)
            if isinstance(ans2, AnswerList):
                return ans2.equals(ans1)
            # Plain lists have no ``equals`` and cannot be subtracted, so compare them
            # position by position.
            return len(ans1) == len(ans2) and all(check_answers(a, b) for a, b in zip(ans1, ans2))

        if not (callable(getattr(ans1, 'equals', None)) and callable(getattr(ans2, 'equals', None))):
            if isinstance(ans1, str) or isinstance(ans2, str):
                return ans1 == ans2
            # Tolerance for float round-off between numeric representations.
            if abs(ans1 - ans2) < 10 ** -10:
                return True
            return False

        return ans1.equals(ans2)
    except Exception as e:
        logger.error(f'Could not establish equality for answers {ans1} and {ans2}, error: {e}')
        return False
|
|
|
|
|
class AnswerList:
    """A collection of answers compared as an unordered multiset.

    On construction, every element whose ``str()`` contains no digit is
    dropped, and a warning is logged for it.
    """

    def __init__(self, answers: list[Any]):
        if not isinstance(answers, (list, tuple)):
            raise ValueError(f"Expected passed answers to be list or tuple, received {type(answers)}")

        kept = []
        for answer in answers:
            if re.search(r'\d', str(answer)) is None:
                logger.warning(f'Could not find any numbers in {answer}, removed from list')
                continue
            kept.append(answer)
        self.answers = kept

    def equals(self, other: list[Any]):
        """Return True if *other* has the same length and each element here pairs
        with a distinct element of *other* under ``check_answers``.

        Pairing is greedy: each element takes the first unused match.
        """
        if len(self.answers) != len(other):
            return False

        claimed = set()
        for mine in self.answers:
            partner = next(
                (idx for idx, theirs in enumerate(other)
                 if idx not in claimed and check_answers(mine, theirs)),
                None,
            )
            if partner is None:
                return False
            claimed.add(partner)
        return True

    def __str__(self):
        return "[" + ",".join(map(str, self.answers)) + "]"

    def __len__(self):
        return len(self.answers)

    def __iter__(self):
        return iter(self.answers)
|
|
|
|
|
class ParseObject:
    """Base interface for the pieces ``ParseList`` can recognise in a list.

    Subclasses override the predicates below and provide ``parse``.
    """

    @classmethod
    def is_at_start(cls, string):
        """Whether *string* could begin an object of this kind (never, by default)."""
        return False

    @classmethod
    def is_complete(cls, string):
        """Whether curly braces and parentheses are balanced by count in *string*."""
        return all(string.count(opener) == string.count(closer) for opener, closer in ("{}", "()"))

    @classmethod
    def is_finished(cls, string):
        """Whether *string* may end here (always, by default)."""
        return True

    @classmethod
    def parse(cls, string):
        raise NotImplementedError
|
|
|
|
|
class ParsePrimitive(ParseObject):
    """Parser for a single scalar item of an answer list (number or expression)."""

    @staticmethod
    def _rewrite_latex(latex_str):
        """Apply one round of LaTeX-to-sympy-syntax rewrites to *latex_str*.

        * Innermost frac/dfrac/tfrac become "(a)/(b)".
        * Innermost binom becomes "binomial(a, b)".
        * Square and n-th roots become powers.
        * "^" becomes "**".
        * cdot and times become "*".
        * pi, e and i commands become pi, E and I.
        """
        latex_str = re.sub(r'\\*(?:dfrac|tfrac|frac)\{([^{}]*)\}\{([^{}]*)\}', r'(\1)/(\2)', latex_str)
        latex_str = re.sub(r'\\*binom\{([^{}]*)\}\{([^{}]*)\}', r'binomial(\1, \2)', latex_str)
        latex_str = re.sub(r'\\*sqrt\[(.*?)\]\{(.*?)\}', r'(\2)**(1/(\1))', latex_str)
        latex_str = re.sub(r'\\*sqrt\{(.*?)\}', r'(\1)**(1/2)', latex_str)
        latex_str = latex_str.replace('^', '**')
        latex_str = latex_str.replace('\\cdot', '*').replace('\\times', '*')
        latex_str = latex_str.replace('\\pi', 'pi').replace('\\e', 'E').replace('\\i', 'I')
        return latex_str

    @classmethod
    def parse(cls, string, primitive_type):
        """Convert *string* to a number or sympy expression.

        Strategies, tried in order:
        1. Strings of decimal digits become ``int`` (``Fraction`` if requested).
        2. Anything ``float()`` accepts becomes ``int`` when integral, else ``float``.
        3. A light LaTeX-to-sympy rewrite, followed by ``sympy.sympify``.
        4. ``latex2sympy_fixed`` on the text after the last "=", evaluated
           numerically with ``N``.

        Args:
            string: Text of one list item.
            primitive_type: Pass ``Fraction`` to get integral results from
                strategies 1-2 as ``Fraction``.

        Returns:
            ``(value, warning)``. ``value`` is None (with ``WarningType.MAJOR``)
            when every strategy fails.
        """
        warning = WarningType.NONE

        # isdecimal (not isdigit): characters such as "²" satisfy isdigit() but int() rejects them.
        if string.isdecimal():
            if primitive_type == Fraction:
                return Fraction(int(string), 1), warning
            return int(string), warning

        try:
            float_string = float(string)
            if int(float_string) == float_string:
                if primitive_type == Fraction:
                    return Fraction(int(float_string), 1), warning
                return int(float_string), warning
            return float_string, warning
        except (ValueError, OverflowError):
            # ValueError: not a float literal (or int(nan)); OverflowError: int(+/-inf).
            pass

        # Turn a bare "sqrt2" into "sqrt{2}" so the sqrt rewrite rules apply.
        if re.search(r'sqrt(\d+)', string):
            string = re.sub(r'sqrt(\d+)', r'sqrt{\1}', string)

        try:
            latex_str = string
            # Rewrite until a fixed point (at most 5 rounds); nested constructs unwrap inside-out.
            for _ in range(5):
                init_str = latex_str
                latex_str = cls._rewrite_latex(latex_str)
                if init_str == latex_str:
                    break
            # Second pass: also turn "{digits}" into "(digits)" before each round.
            for _ in range(5):
                init_str = latex_str
                latex_str = re.sub(r'\{(\d+)\}', r'(\1)', latex_str)
                latex_str = cls._rewrite_latex(latex_str)
                if init_str == latex_str:
                    break

            # Make implicit multiplication explicit around "(" / ")" and between digits
            # and one- or two-letter names.
            latex_str = re.sub(r'(\d|(?<![a-zA-Z])[a-zA-Z]{1,2}(?![a-zA-Z]))\(', r'\1*(', latex_str)
            latex_str = re.sub(r'\)(\d|(?<![a-zA-Z])[a-zA-Z]{1,2}(?![a-zA-Z]))', r')*\1', latex_str)
            latex_str = re.sub(r'(?<=\d)((?<![a-zA-Z])[a-zA-Z]{1,2}(?![a-zA-Z]))', r'*\1', latex_str)
            latex_str = re.sub(r'((?<![a-zA-Z])[a-zA-Z]{1,2}(?![a-zA-Z]))(?=\d)', r'\1*', latex_str)
            # Any remaining innermost "{...}" group becomes a "[...]" list literal.
            latex_str = re.sub(r'\{([^{}]*)\}', lambda m: '[' + m.group(1).replace(',', ', ') + ']',
                               latex_str)

            string = sympy.sympify(latex_str,
                                   locals={'binomial': sympy.binomial,
                                           'pi': sympy.pi,
                                           'E': sympy.E,
                                           'I': sympy.I}
                                   )
        except Exception:
            # sympify failed: fall back to latex2sympy on the right-hand side of the last "=".
            string_no_eq = string
            try:
                if "=" in string_no_eq:
                    string_no_eq = string_no_eq[string_no_eq.rfind("=")+1:]
                float_val = float(N(latex2sympy_fixed(string_no_eq), 101))
                # Integral (or too large for a float) values are re-evaluated with enough
                # precision to return an exact int.
                if float_val.is_integer() or float("inf") == float_val or float("-inf") == float_val:
                    return int(N(latex2sympy_fixed(string_no_eq), 50001)), warning
                return float_val, warning
            except Exception as e:
                logger.warning(f"Error: Custom parsing error {e}, {string_no_eq}")
                warning = max(warning, WarningType.MAJOR)
                return None, warning

        return string, warning

    @classmethod
    def is_at_start(cls, string):
        """A primitive may start anywhere."""
        return True
|
|
|
|
|
class ParseList(ParseObject):
    """Parser for a parenthesised, delimiter-separated list of answers."""

    @classmethod
    def is_at_start(cls, string):
        return string.startswith(r"(")

    @classmethod
    def is_finished(cls, string):
        # Surrounding whitespace and commas are ignored when checking for the closing ")".
        return string.strip().strip(",").endswith(")")

    @classmethod
    def is_complete(cls, string):
        return string.count("(") == string.count(")")

    @classmethod
    def never_zero_count(cls, string):
        """Return True if the running count of "(" minus ")" is never zero
        after any character of *string* (vacuously True for "")."""
        count = 0
        ever_zero = False
        for char in string:
            if char == "(":
                count += 1
            if char == ")":
                count -= 1
            if count == 0:
                ever_zero = True
        return not ever_zero

    @classmethod
    def parse(cls, string, delimiter=(r"\n", ","), primitive_type=None, depth=0):
        """Parse a parenthesised list into (possibly nested) Python lists.

        Args:
            string: Text that must start with "(".
            delimiter: Candidate separators (a single string is also accepted).
                The first one present in the unwrapped text is used for
                splitting. The default's first entry is the two characters
                backslash + "n", not a newline.
            primitive_type: Forwarded to ``ParsePrimitive.parse``.
            depth: Recursion depth. Beyond 50, nested lists are no longer
                recognised.

        Returns:
            ``(items, warning)``, or ``(None, WarningType.MAJOR)`` when *string*
            does not start with "(" or cannot be split into parseable items.

        Raises:
            ValueError: if no progress can be made at depth > 50.
        """
        if isinstance(delimiter, str):
            delimiter = [delimiter]
        if not string.startswith("("):
            return None, WarningType.MAJOR

        string = string.strip().strip(",")
        # Unwrap the outer "(...)" when the first "(" stays open until the last character.
        if cls.never_zero_count(string[:-1]):
            string = string[1:-1]
        string = strip(string)

        # Split on the first candidate delimiter that actually occurs; default to the first.
        used_delim = next((delim for delim in delimiter if delim in string), delimiter[0])

        output = []
        warning = WarningType.NONE
        while len(string) > 0:
            previous_string = string
            comma_separated = string.split(used_delim)
            allowed_objects = [ParseList, ParsePrimitive]
            if depth > 50:
                allowed_objects = [ParsePrimitive]
            for obj in allowed_objects:
                if not obj.is_at_start(strip(string)):
                    continue
                # Grow the candidate one chunk at a time until it forms a whole object
                # (or all chunks are used).
                current_index = 1
                candidate = strip(used_delim.join(comma_separated[:current_index]))
                while not (obj.is_complete(candidate) and obj.is_finished(candidate)):
                    current_index += 1
                    candidate = strip(used_delim.join(comma_separated[:current_index]))
                    if current_index >= len(comma_separated):
                        break
                if not (obj.is_complete(candidate) and obj.is_finished(candidate)):
                    continue

                if obj is ParseList:
                    parsed, new_warning = obj.parse(candidate, primitive_type=primitive_type,
                                                    depth=depth + 1)
                else:
                    parsed, new_warning = obj.parse(candidate, primitive_type=primitive_type)
                warning = max(warning, new_warning)
                output.append(parsed)
                string = strip(used_delim.join(comma_separated[current_index:]))
                break

            if previous_string == string:
                # No object consumed any input: give up.
                if depth > 50:
                    logger.error(f"Response {string} reached depth > 50")
                    raise ValueError(f"Failed to parse '{string}'")
                return None, WarningType.MAJOR

        return output, warning