import regex
from typing import Optional
from loguru import logger
import re
from fractions import Fraction
from matharena.utils import latex2sympy_fixed
from sympy import N, Integer
import sympy
from typing import Any
from enum import Enum
from functools import total_ordering
from matharena.parse_manual import manual_mapper


@total_ordering
class WarningType(Enum):
    """Severity attached to an extracted or parsed answer.

    Members are ordered by their integer value (NONE < MINOR < POSSIBLE < MAJOR),
    which lets callers combine warnings with ``max(...)``.
    """
    NONE = 0
    MINOR = 1
    POSSIBLE = 2
    MAJOR = 3

    def __lt__(self, other):
        # Against another WarningType, compare values; against anything else
        # (e.g. a plain int), compare this member's value directly to ``other``.
        # total_ordering derives <=, > and >= from this method.
        if self.__class__ is other.__class__:
            return self.value < other.value
        return self.value < other


def parse_grading(text: str):
    """Parse grader output made of repeated "Category / Points awarded / Description" sections.

    Parameters:
    text (str): Raw grader text. Each section looks like
        "Category: <title> Points awarded: <int> Description: <text>".

    Returns:
    dict: {"points": total of all awarded points,
           "details": [{"title": str, "points": int, "desc": str}, ...]}

    Raises:
    ValueError: if a "Points awarded" field is not an integer literal (from ``int()``).
    """
    pattern = re.compile(r"""Category:\s*(.*?)(?:\s|\n)*Points\s+awarded:\s*(.*?)(?:\s|\n)*Description:\s*(.*?)(?:\s|\n)*(?=Category:|$)""", re.DOTALL | re.VERBOSE)
    matches = pattern.findall(text)
    result = {"points": sum(int(points) for _, points, _ in matches), "details": []}
    for title, points, desc in matches:
        result["details"].append({
            "title": title,
            "points": int(points),
            "desc": desc
        })
    return result


def find_last_boxed_content(text: str, list_answer: bool = False) -> tuple[Optional[str], WarningType]:
    """Return the content of the last ``boxed{...}`` / ``fbox{...}`` in ``text``.

    Uses the third-party ``regex`` module's recursive group ``(?2)`` so that nested
    braces inside the box are matched in balance. The pattern does not require a
    leading backslash, so a bare "boxed{" also matches.

    Parameters:
    text (str): Model response to search.
    list_answer (bool): If True and there is more than one box, take every box on
        the last line that contains one and join their contents with ",".

    Returns:
    tuple: (content or None, WarningType). The warning is POSSIBLE when boxes were
    joined from one line, otherwise NONE (also NONE when nothing is boxed).
    """
    pattern = r"(boxed|fbox)\{((?:[^{}]|\{(?2)\})*)\}"
    matches = list(regex.finditer(pattern, text))
    if not matches:
        return None, WarningType.NONE
    if len(matches) > 1 and list_answer:
        # find all boxed content on the same line (no \n in between) as the last boxed
        split_text = text.split("\n")
        for i in range(len(split_text)-1, -1, -1):
            matches_line = list(regex.finditer(pattern, split_text[i]))
            if len(matches_line) > 0:
                returned_boxed = ",".join([match.group(2) for match in matches_line])
                return returned_boxed, WarningType.POSSIBLE
    # Single box, list mode off, or no single line contained a complete box
    # (e.g. a box whose braces span several lines): fall back to the last match.
    last_match = matches[-1]
    return last_match.group(2), WarningType.NONE


def extract_boxed_answer(text: str, list_answer: bool = False) -> tuple[Optional[str], WarningType]:
    """Return the last boxed answer, keeping only the text after its last "=".

    For example, a box containing "x = 5" yields " 5".

    Returns:
    tuple: (answer or None, WarningType from find_last_boxed_content).
    """
    answer, warning = find_last_boxed_content(text, list_answer)
    if answer is not None and "=" in answer:
        answer = answer.split("=")[-1]
    if answer is not None:
        return answer, warning
    else:
        return None, warning


def extract_boxed_answer_parse(text: str, parse: bool = True, list_answer: bool = False) -> tuple[Any, WarningType]:
    """Extract the boxed answer and convert it to a Python/sympy value.

    An answer that ``int()`` accepts becomes a ``sympy.Integer``. Otherwise it is
    passed to ``parse_answer`` when ``parse`` is True, or returned as the raw string.

    Returns:
    tuple: (value, WarningType). Returns (None, MAJOR) when nothing is boxed.
    """
    answer, warning = extract_boxed_answer(text, list_answer)
    if answer is not None:
        try:
            return sympy.Integer(int(answer)), warning
        # NOTE(review): bare except also catches KeyboardInterrupt/SystemExit;
        # (ValueError, TypeError) is probably what is meant.
        except:
            # logger.info(f"Could not parse answer {answer} as integer")
            if parse:
                # NOTE(review): this rebinds ``warning`` to parse_answer's warning, so a
                # POSSIBLE warning from multi-box extraction (list_answer) is dropped here.
                # max(...) of both was probably intended.
                parsed_answer, warning = parse_answer(answer)
                return parsed_answer, warning
            return answer, warning
    return None, WarningType.MAJOR


def extract_last_integer(text: str) -> tuple[Optional[int], WarningType]:
    """Fallback extractor: return the last standalone run of digits in ``text`` as an int.

    Always reports WarningType.MAJOR, because an unboxed number is only a guess.
    """
    pattern = r"\b\d+\b"
    matches = list(regex.finditer(pattern, text))
    if not matches:
        return None, WarningType.MAJOR
    try:
        return int(matches[-1].group()), WarningType.MAJOR
    except:
        return None, WarningType.MAJOR


def extract_answer(text: str, strict_parsing: bool = True, parse: bool = True, list_answer: bool = False) -> tuple[Any, WarningType]:
    """Top-level entry point: extract (and optionally parse) the final answer of a response.

    Parameters:
    text (str): Model response.
    strict_parsing (bool): If True, only boxed answers count. If False and nothing
        boxed was found, fall back to the last integer in the text.
    parse (bool): Forwarded to extract_boxed_answer_parse.
    list_answer (bool): Forwarded to find_last_boxed_content.

    Returns:
    tuple: (answer or None, highest WarningType raised along the way).
    """
    # Turn Unicode box-drawing "boxes" into \boxed{...} before searching.
    text, warning = replace_unicode(text)
    answer, warning_new = extract_boxed_answer_parse(text, parse, list_answer)
    warning = max(warning, warning_new)
    if answer is not None or strict_parsing:
        return answer, warning

    return extract_last_integer(text)


def parse_answer(s: str, primitive_type: Optional[type] = None) -> tuple[Any, WarningType]:
    """Parse an answer string into a number, sympy expression or AnswerList.

    Steps: apply any hand-written override from ``manual_mapper`` (flagged MAJOR),
    drop LaTeX spacing commands, normalise brackets and delimiters, then parse the
    string as a parenthesised list with ``ParseList.parse``. ParseList is defined
    later in the original file; its definition is not intact in this copy.
    A one-element result is unwrapped. A list/tuple result is wrapped in AnswerList.

    Returns:
    tuple: (parsed value or None, WarningType). None comes with at least MAJOR.
    """
    warning = WarningType.NONE
    if s in manual_mapper:
        logger.warning(f"Applying manual parsing to {s}")
        s = manual_mapper[s]
        warning = WarningType.MAJOR
    s = remove_invalid_characters(s)
    s = remove_outer_brackets(normalize_string(s))
    output, warning_new = ParseList.parse("(" + s + ")", primitive_type=primitive_type)
    warning = max(warning, warning_new)
    if output is None:
        logger.warning(f"Could not parse {s}, returning None")
        return None, max(warning, WarningType.MAJOR)
    if len(output) == 1:
        output = output[0]

    if isinstance(output, list) or isinstance(output, tuple):
        output = AnswerList(output)
    return output, warning


def normalize_string(s: str) -> str:
    """Canonicalise LaTeX delimiters so that every bracket type becomes "(" / ")".

    Removes sizing commands (\\left, \\right, \\Bigl, ...), flattens align
    environments, maps [ ] and \\{ \\} to ( ), drops "$", \\hline and \\vline,
    and finishes with this module's ``strip``.
    """
    s = s.replace(r"\left", "").replace(r"\right", "")
    s = s.replace(r"\Bigl", "").replace(r"\Bigr", "")
    s = s.replace(r"\bigl", "").replace(r"\bigr", "")
    s = remove_aligns(s)
    s = s.replace("[", "(")
    s = s.replace("]", ")")
    s = s.replace("\\{", "(")  # sets will be converted to lists
    s = s.replace("\\}", ")")  # sets will be converted to lists
    s = s.replace("$", "")
    # remove hline and vline
    s = s.replace(r"\hline", "")
    s = s.replace(r"\vline", "")
    return strip(s)


def remove_outer_brackets(s: str) -> str:
    """
    Removes the outermost matching brackets from the string if they encompass the entire string.

    Only round parentheses are considered; the removal repeats while the whole
    string is still wrapped in one matching pair, so "((1,2))" becomes "1,2".

    Parameters:
    s (str): The input string potentially wrapped with brackets.

    Returns:
    str: The string with the outermost brackets removed if they match and encompass the entire string.
    """
    while True:
        if not s:
            return s
        opening = s[0]
        closing = s[-1]

        if opening == "(" and closing == ")":
            count = 0
            matched = True
            for i, char in enumerate(s):
                if char == opening:
                    count += 1
                elif char == closing:
                    count -= 1
                # Depth returning to 0 before the last character means the first "("
                # closes early, e.g. "(1)+(2)", so the outer pair does not wrap everything.
                if count == 0 and i != len(s) - 1:
                    matched = False
                    break

            if matched:
                s = s[1:-1]
                continue
        break

    return s


def remove_aligns(s: str) -> str:
    """Replace every \\begin{align...}...\\end{align...} block with its body, minus "&" and "\\\\"."""
    # This pattern captures:
    #   \begin{align followed by any non-} characters (like align*, alignat, etc.)
    #   then any content (non-greedily) up to
    #   \end{align...}. The suffix after "align" in the end tag is matched
    #   independently; it is not required to equal the begin tag's suffix.
    pattern = r'\\begin{align[^}]*}(.*?)\\end{align[^}]*}'

    # Use a callback to remove '&' and '\\' line breaks from the matched group before returning it
    return re.sub(pattern, lambda m: m.group(1).replace('&', '').replace("\\\\", ""), s, flags=re.DOTALL)


def replace_unicode(text: str) -> tuple[str, WarningType]:
    """Map Unicode glyphs to LaTeX.

    Box-drawing / bracket-piece characters that some models use to draw a box are
    mapped to "\\boxed{" and "}". If any of those substitutions changed the text,
    the warning is POSSIBLE. The later √, × and narrow no-break space replacements
    do not affect the warning.

    Returns:
    tuple: (converted text, WarningType.NONE or WarningType.POSSIBLE).
    """
    text_old = text
    text = text.replace("\u23a7", r"\boxed{")
    text = text.replace("\u23ab", r"}")
    text = text.replace("\n\u2502", r"\boxed{")
    text = text.replace("\u2502", r"}")
    text = text.replace("\n\u2503", r"\boxed{")
    text = text.replace("\u2503", r"}")
    text = text.replace("\n\uf8f0", r"\boxed{")
    text = text.replace("\uf8fb", r"}")
    warning = WarningType.NONE if text == text_old else WarningType.POSSIBLE
    text = text.replace("\u221a", r"\sqrt")  # these ones are for sure fine, no warning necessary
    text = text.replace("\u00d7", r"\cdot")
    text = text.replace("\u202f", r" ")
    return text, warning


def remove_invalid_characters(text: str) -> str:
    """Remove the LaTeX spacing commands "\\;" and "\\,"."""
    text = re.sub(r'\\;', '', text)
    text = re.sub(r'\\,', '', text)
    return text


def strip(s: str) -> str:
    """Strip whitespace plus leftover LaTeX/escape noise from both ends of ``s``.

    Besides whitespace, this removes:
    - literal two-character "\\n" sequences (backslash + n, not newlines) at either end;
    - leading "\\ " (backslash-space);
    - a leading run of backslashes (optionally followed by a newline) before "(".
    """
    s = s.strip()
    # be careful with this, it can also remove the "\" in "\begin" if just done with strip
    while s.startswith(r"\n"):
        s = s[2:]
    while s.endswith(r"\n"):
        s = s[:-2]
    while s.startswith("\\ "):
        s = s[2:]
    # if s starts with any thing of the form \\\ and then a bracket, or \\\n and then a bracket, remove it
    # NOTE(review): this always drops 3 characters. For exactly two backslashes
    # followed by "(", that slice also removes the "(" itself; confirm this is intended.
    while re.match(r"\\{2,}\n?\(", s):
        s = s[3:]
    return s


def check_answers(ans1, ans2):
    """Return True if two parsed answers are considered equal.

    - Either side None -> False.
    - One side a list/AnswerList and the other not -> False.
    - If either side lacks a callable ``equals``: strings are compared with ==,
      anything else by absolute difference < 1e-10.
    - Otherwise the result of ``ans1.equals(ans2)`` is returned (sympy objects,
      AnswerList).
    Any exception is logged and treated as "not equal".
    """
    if ans1 is None or ans2 is None:
        return False
    if (type(ans1) in [list, AnswerList]) != (type(ans2) in [list, AnswerList]):
        return False
    try:
        if not (hasattr(ans1, 'equals') and callable(ans1.equals)) \
                or not (hasattr(ans2, 'equals') and callable(ans2.equals)):
            # do approximate equal here
            if isinstance(ans1, str) or isinstance(ans2, str):
                return ans1 == ans2
            if abs(ans1 - ans2) < 10 ** -10:
                return True
            return False
        return ans1.equals(ans2)
    except Exception as e:
        logger.error(f'Could not establish equality for answers {ans1} and {ans2}, error: {e}')
        return False


class AnswerList:
    """A multi-element answer whose elements are compared without regard to order."""

    def __init__(self, answers: list[Any]):
        """Keep only the elements whose string form contains at least one digit.

        Raises:
        ValueError: if ``answers`` is neither a list nor a tuple.
        """
        if not isinstance(answers, list) and not isinstance(answers, tuple):
            raise ValueError(f"Expected passed answers to be list or tuple, received {type(answers)}")

        valid_answers = []
        for answer in answers:
            if bool(re.search(r'\d', str(answer))):
                valid_answers.append(answer)
            else:
                logger.warning(f'Could not find any numbers in {answer}, removed from list')

        self.answers = list(valid_answers)

    def equals(self, other: list[Any]) -> bool:
        """Return True if ``other`` has the same length and its elements pair up with ours.

        Each of our elements is greedily matched, via check_answers, to the first
        not-yet-used element of ``other``.
        """
        if len(self.answers) != len(other):
            # logger.info(f"Lists {self.answers} and {other} do not have the same length.")
            return False

        match_ids = set()  # indices in ``other`` already paired with one of our elements
        for ans1 in self.answers:
            match_found = False
            for i, ans2 in enumerate(other):
                if i not in match_ids and check_answers(ans1, ans2):
                    match_ids.add(i)
                    match_found = True
                    break
            if not match_found:
                # logger.info(f"Could not find a match for element {ans1} in {other}")
                return False
        return True

    def __str__(self):
        return '[' + ','.join([str(ans) for ans in self.answers]) + ']'

    def __len__(self):
        return len(self.answers)

    def __iter__(self):
        return iter(self.answers)


class ParseObject:
    """Base interface for the parser pieces.

    The list parser (the loop at the end of this file) uses these hooks to decide
    how many delimiter-separated chunks belong to one object:
    is_at_start / is_complete / is_finished. ``parse`` then converts those chunks.
    """

    @classmethod
    def is_at_start(cls, string):
        return False

    @classmethod
    def is_complete(cls, string):
        # Balanced "{}" and "()" counts mean the chunk does not cut a group in half.
        return string.count("{") == string.count("}") and string.count("(") == string.count(")")

    @classmethod
    def is_finished(cls, string):
        return True

    @classmethod
    def parse(cls, string):
        raise NotImplementedError


class ParsePrimitive(ParseObject):
    """Parse a single scalar: an integer, a float, or a LaTeX-ish math expression."""

    @classmethod
    def parse(cls, string, primitive_type):
        """Return (value, WarningType) for one primitive answer string.

        Integers and integral floats become ``int`` (or ``Fraction`` when
        ``primitive_type is Fraction``). Other floats stay ``float``. Anything else
        is rewritten from LaTeX toward a sympy-parsable string.
        """
        warning = WarningType.NONE
        # Integer
        if string.isdigit():
            if primitive_type == Fraction:
                # NOTE(review): returns a bare Fraction, not (value, warning). The
                # caller at the end of this file does ``parsed, new_warning = obj.parse(...)``,
                # so this path would fail there. Should be ``..., warning``.
                return Fraction(int(string), 1)
            return int(string), warning
        # Float
        try:
            float_string = float(string)
            if int(float_string) == float_string:
                if primitive_type == Fraction:
                    # NOTE(review): same missing ``, warning`` as the integer branch above.
                    return Fraction(int(float_string), 1)
                return int(float_string), warning
            return float_string, warning
        except ValueError:
            # logger.info(f"Couldn't configure floating point to fraction for {string}")
            pass

        # Expression
        # "sqrt5" -> "sqrt{5}" so the brace-based sqrt rewrites below apply.
        if bool(re.search(r'sqrt(\d+)', string)):
            string = re.sub(r'sqrt(\d+)', r'sqrt{\1}', string)
        try:
            latex_str = string
            # Rewrite innermost \frac/\binom/\sqrt and ^, \cdot, \times, \pi ... into
            # Python/sympy syntax. Repeat (at most 5 times) until a fixed point, since
            # each pass only handles brace groups without nested braces.
            # NOTE(review): replace('\\e', 'E') also rewrites any other command that
            # starts with \e (e.g. \epsilon); confirm that is acceptable.
            for _ in range(5):
                init_str = latex_str
                latex_str = re.sub(r'\\*(?:dfrac|tfrac|frac)\{([^{}]*)\}\{([^{}]*)\}', r'(\1)/(\2)', latex_str)
                latex_str = re.sub(r'\\*binom\{([^{}]*)\}\{([^{}]*)\}', r'binomial(\1, \2)', latex_str)
                latex_str = re.sub(r'\\*sqrt\[(.*?)\]\{(.*?)\}', r'(\2)**(1/(\1))', latex_str)
                latex_str = re.sub(r'\\*sqrt\{(.*?)\}', r'(\1)**(1/2)', latex_str)
                latex_str = latex_str.replace('^', '**')
                latex_str = latex_str.replace('\\cdot', '*').replace('\\times', '*')
                latex_str = latex_str.replace('\\pi', 'pi').replace('\\e', 'E').replace('\\i', 'I')
                if init_str == latex_str:
                    break

            # Second round: same rewrites, additionally turning "{digits}" into
            # "(digits)" (e.g. exponents such as 2**{10}) first.
            for _ in range(5):
                init_str = latex_str
                latex_str = re.sub(r'\{(\d+)\}', r'(\1)', latex_str)
                latex_str = re.sub(r'\\*(?:dfrac|tfrac|frac)\{([^{}]*)\}\{([^{}]*)\}', r'(\1)/(\2)', latex_str)
                latex_str = re.sub(r'\\*binom\{([^{}]*)\}\{([^{}]*)\}', r'binomial(\1, \2)', latex_str)
                latex_str = re.sub(r'\\*sqrt\[(.*?)\]\{(.*?)\}', r'(\2)**(1/(\1))', latex_str)
                latex_str = re.sub(r'\\*sqrt\{(.*?)\}', r'(\1)**(1/2)', latex_str)
                latex_str = latex_str.replace('^', '**')
                latex_str = latex_str.replace('\\cdot', '*').replace('\\times', '*')
                latex_str = latex_str.replace('\\pi', 'pi').replace('\\e', 'E').replace('\\i', 'I')
                if init_str == latex_str:
                    break

            # Handle implicit multiplication
            # NOTE(review): THIS COPY OF THE FILE IS CORRUPTED FROM HERE ON. The text
            # between "(?" and "0:" on the next line is missing (it looks like a "<...>"
            # span was stripped). The gap covered the rest of ParsePrimitive.parse and
            # the start of the list-parser class; what follows "0:" appears to be that
            # class's parse() loop. The line is kept verbatim and cannot run as-is.
            # Restore it from version control rather than reconstructing it by hand.
            latex_str = re.sub(r'(\d|(? 0: previous_string = string comma_separated = string.split(used_delim) allowed_objects = [ParseList, ParsePrimitive] if depth > 50: allowed_objects = [ParsePrimitive] for obj in allowed_objects: if obj.is_at_start(strip(string)): current_index = 1 while not obj.is_complete(strip(used_delim.join(comma_separated[:current_index]))) or \ not obj.is_finished(strip(used_delim.join(comma_separated[:current_index]))): current_index += 1 if current_index >= len(comma_separated): break if not obj.is_complete(strip(used_delim.join(comma_separated[:current_index]))) or \ not obj.is_finished(strip(used_delim.join(comma_separated[:current_index]))): continue if obj == ParseList: parsed, new_warning = obj.parse(strip(used_delim.join(comma_separated[:current_index])), primitive_type=primitive_type, depth=depth+1) else: parsed, new_warning = obj.parse(strip(used_delim.join(comma_separated[:current_index])), primitive_type=primitive_type) warning = max(warning, new_warning) output.append(parsed) string = strip(used_delim.join(comma_separated[current_index:])) break if previous_string == string: if depth > 50: logger.error(f"Response {string} reached depth > 50") raise ValueError(f"Failed to parse '{string}'") return None, WarningType.MAJOR return output, warning