import pandas as pd
from tqdm import tqdm
import logging
import re
from word2number import w2n
from dataflow.utils.utils import get_logger


class StringProcessor:
    """
    A class that encapsulates various string processing functions for mathematical expressions.
    """

    @staticmethod
    def _fix_fracs(string):
        r"""
        Fixes fraction expressions in the string, ensuring they are properly formatted as \frac{a}{b}.

        Every ``\frac`` that is not directly followed by ``{`` gets braces
        around its next one or two characters (``\frac12`` -> ``\frac{1}{2}``,
        ``\frac1{72}`` -> ``\frac{1}{72}``).

        Args:
            string: The expression to normalise.

        Returns:
            The normalised expression, or the input unchanged when a
            brace-less ``\frac`` is followed by fewer than two characters.
        """
        substrs = string.split("\\frac")
        pieces = [substrs[0]]
        for substr in substrs[1:]:
            pieces.append("\\frac")
            if substr.startswith("{"):
                # Already of the form \frac{...; copy through untouched.
                pieces.append(substr)
                continue
            if len(substr) < 2:
                # Not enough characters for a numerator and a denominator.
                return string
            a, b = substr[0], substr[1]
            if b != "{":
                pieces.append(f"{{{a}}}{{{b}}}{substr[2:]}")
            else:
                # The denominator is already braced; wrap only the numerator.
                pieces.append(f"{{{a}}}{b}{substr[2:]}")
        return "".join(pieces)

    @staticmethod
    def _fix_a_slash_b(string):
        r"""
        Fixes cases where a fraction is represented as a simple division (e.g., a/b) and converts it to \frac{a}{b}.

        The conversion only happens when the string contains exactly one
        ``/`` and each side is either a canonical integer literal or contains
        ``sqrt``.

        Args:
            string: The expression to convert.

        Returns:
            The ``\frac{a}{b}`` form, or the input unchanged when it does not
            qualify.
        """
        parts = string.split("/")
        if len(parts) != 2:
            return string
        a, b = parts
        try:
            a = a if "sqrt" in a else int(a)
            b = b if "sqrt" in b else int(b)
        except ValueError:
            return string
        # Explicit round-trip check instead of `assert`, which is stripped
        # under `python -O`: rejects forms such as " 1/2", "+1/2" or "01/2".
        if string != f"{a}/{b}":
            return string
        return f"\\frac{{{a}}}{{{b}}}"

    @staticmethod
    def _fix_sqrt(string):
        r"""
        Ensures that square root expressions are properly formatted as \sqrt{...}.

        Args:
            string: The expression to normalise.

        Returns:
            The expression with each ``\sqrt`` that is directly followed by
            word characters rewritten so that run is wrapped in braces.
        """
        return re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string)

    @staticmethod
    def convert_word_number(text: str) -> str:
        """
        Converts a word representation of a number to a digit.

        Args:
            text: Text to hand to ``w2n.word_to_num``.

        Returns:
            The parsed number as a string, or ``text`` unchanged when the
            parser raises.
        """
        try:
            return str(w2n.word_to_num(text))
        except Exception:
            # Best effort: anything the parser rejects is passed through.
            return text


# Unit Text Class to Manage Unit Texts
class UnitTextManager:
    """
    A class that encapsulates unit text management to remove unwanted unit terms from strings.
    """

    def __init__(self):
        """
        Initializes the unit texts and their plural forms.
        """
        self.unit_texts = [
            "east", "degree", "mph", "kmph", "ft", "m sqaure", "m east",
            "sq m", "deg", "mile", "q .", "monkey", "prime", "ratio",
            "profit of rs", "rd", "o", "gm", "p . m", "lb", "tile", "per",
            "dm", "lt", "gain", "ab", "way", "west", "a .", "b .", "c .",
            "d .", "e .", "f .", "g .", "h .", "t", "a", "h", "no change",
            "men", "soldier", "pie", "bc", "excess", "st", "inches", "noon",
            "percent", "by", "gal", "kmh", "c", "acre", "rise", "a . m",
            "th", "π r 2", "sq", "mark", "l", "toy", "coin", "sq . m",
            "gallon", "° f", "profit", "minw", "yr", "women", "feet", "am",
            "pm", "hr", "cu cm", "square", "v â € ™", "are", "rupee",
            "rounds", "cubic", "cc", "mtr", "s", "ohm", "number", "kmph",
            "day", "hour", "minute", "min", "second", "man", "woman", "sec",
            "cube", "mt", "sq inch", "mp", "∏ cm ³", "hectare", "more",
            "sec", "unit", "cu . m", "cm 2", "rs .", "rs", "kg", "g",
            "month", "km", "m", "cm", "mm", "apple", "liter", "loss", "yard",
            "pure", "year", "increase", "decrease", "d", "less", "Surface",
            "litre", "pi sq m", "s .", "metre", "meter", "inch",
        ]
        # Also match the naive plural ("<unit>s") of every entry.
        self.unit_texts.extend([t + "s" for t in self.unit_texts])

    def clean_units(self, string: str):
        """
        Cleans the string by removing unit terms from it.

        A unit is removed only where it is delimited on each side by the
        start/end of the string or by a non-word character; the delimiters
        themselves are kept.

        Args:
            string: The text to clean.

        Returns:
            The text with every delimited unit occurrence removed.
        """
        for unit_text in self.unit_texts:
            # Units are literal text: escape them so that e.g. the "." in
            # "a ." or "rs ." is not treated as a regex wildcard.
            pattern = r"(^|\W)" + re.escape(unit_text) + r"($|\W)"
            string = re.sub(pattern, r"\1\2", string)
        return string


# Main String Processing Class
class StringCleaner:
    """
    A class responsible for cleaning and formatting strings in mathematical expressions.
    """

    def __init__(self, unit_manager: UnitTextManager):
        """
        Initializes the StringCleaner class with a unit manager.
        """
        self.unit_manager = unit_manager

    def strip_string(self, string, skip_unit=False):
        """
        Strips unwanted characters and units from the string.

        Args:
            string: The value to clean; it is converted with ``str()`` first.
            skip_unit: When true, the unit-removal pass is skipped.

        Returns:
            The cleaned string.
        """
        string = str(string).strip().replace("\n", "").rstrip(".").replace("\\!", "")

        # Matrix environments are all mapped onto pmatrix.
        string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string)
        string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string)
        string = string.replace("bmatrix", "pmatrix")

        # Canonical spellings for fractions and comparison operators, then
        # removal of sizing commands and escaped braces (order matters).
        for old, new in (
            ("tfrac", "frac"),
            ("dfrac", "frac"),
            ("\\neq", "\\ne"),
            ("\\leq", "\\le"),
            ("\\geq", "\\ge"),
            ("\\left", ""),
            ("\\right", ""),
            ("\\{", "{"),
            ("\\}", "}"),
        ):
            string = string.replace(old, new)

        # Clean unit texts if needed
        if not skip_unit:
            string = self.unit_manager.clean_units(string)

        # Degree marks, dollar signs and inline-math delimiters.
        for token in ("^{\\circ}", "^\\circ", "\\$", "$", "\\(", "\\)"):
            string = string.replace(token, "")

        string = StringProcessor.convert_word_number(string)
        string = re.sub(r"\\text\{(.*?)\}", r"\1", string)

        # Drop leading variable bindings such as "x=" or "x\in".
        for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]:
            string = string.replace(key, "")

        string = string.replace("\\emptyset", r"{}")
        string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}")

        # Remove percent signs and give bare decimals a leading zero.
        string = string.replace("%", "").replace(" .", " 0.").replace("{.", "{0.")
        return string


# Core Answer Extraction Logic Class
class AnswerExtractor:
    """
    A class responsible for extracting the final answer from a prediction string.
    """

    def __init__(self, string_cleaner: StringCleaner):
        """
        Initializes the AnswerExtractor class with a string cleaner.
        """
        self.string_cleaner = string_cleaner
        self.logger = get_logger()

    def extract_answer(self, pred_str, data_name, use_last_number=True):
        """
        Extracts the final answer from the prediction string, processing various formats.

        The first matching rule wins: the "final answer is $...$. I hope"
        template, then the text after the last "boxed", then the text after
        the last "he answer is", then the last number in the text.

        Args:
            pred_str: The raw prediction; falsy values are treated as "".
            data_name: Dataset name; unit removal is skipped for "carp_en"
                and "minerva_math".
            use_last_number: Whether the last-number fallback is enabled.

        Returns:
            The extracted answer after ``StringCleaner.strip_string``.
        """
        if not pred_str:
            pred_str = ""
        # Remove the Cyrillic sequence "ки" (presumably a marker token
        # emitted by some models -- confirm against the data).
        pred_str = str(pred_str).replace("\u043a\u0438", "")

        # Handle special cases based on data_name or pattern
        if "final answer is $" in pred_str and "$. I hope" in pred_str:
            tail = pred_str.split("final answer is $", 1)[1]
            pred = tail.split("$. I hope", 1)[0].strip()
        elif "boxed" in pred_str:
            pred = self._extract_boxed_answer(pred_str)
        elif "he answer is" in pred_str:
            # "he answer is" matches both "The answer is" and "the answer is".
            pred = pred_str.split("he answer is")[-1].strip()
        else:
            pred = self._get_last_number_answer(pred_str, use_last_number)

        skip_unit = data_name in ["carp_en", "minerva_math"]
        pred = self.string_cleaner.strip_string(pred, skip_unit=skip_unit)
        return pred

    def _extract_boxed_answer(self, pred_str):
        """
        Extracts answers enclosed in 'boxed' notation.

        Args:
            pred_str: Text containing at least one "boxed".

        Returns:
            The brace-balanced content after the last "boxed" when it is
            followed by "{", otherwise the stripped text up to the next "$".
        """
        ans = pred_str.split("boxed")[-1]
        if ans.startswith("{"):
            return self._extract_bracketed_answer(ans)
        return ans.split("$")[0].strip()

    def _extract_bracketed_answer(self, ans):
        """
        Handles answers that are enclosed within brackets.

        Args:
            ans: Text whose first character is the opening "{".

        Returns:
            Everything between that brace and its matching "}"; if the
            braces never balance, everything after the opening brace.
        """
        depth = 1
        chars = []
        for c in ans[1:]:
            if c == "{":
                depth += 1
            elif c == "}":
                depth -= 1
                if depth == 0:
                    # Matching close brace reached; it is not part of the answer.
                    break
            chars.append(c)
        return "".join(chars)

    def _get_last_number_answer(self, pred_str, use_last_number):
        """
        Extracts the last number from the string if use_last_number is True.

        Args:
            pred_str: The text to scan; commas are removed before matching.
            use_last_number: When false, no scan is performed.

        Returns:
            The last integer/decimal literal found, or "" when there is none
            or the fallback is disabled.
        """
        if not use_last_number:
            return ""
        pattern = r"-?\d*\.?\d+"
        numbers = re.findall(pattern, pred_str.replace(",", ""))
        return numbers[-1] if numbers else ""