File size: 18,104 Bytes
0828e8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
import regex
from typing import Optional
from loguru import logger
import re
from fractions import Fraction
from matharena.utils import latex2sympy_fixed
from sympy import N, Integer
import sympy
from typing import Any
from enum import Enum
from functools import total_ordering
from matharena.parse_manual import manual_mapper

@total_ordering
class WarningType(Enum):
    """Severity level attached to a parsed answer; a larger value is more severe.

    Only ``__lt__`` is written by hand. ``total_ordering`` derives the other
    comparison operators from it, so members can be compared with each other
    and combined with ``max()``.
    """

    NONE = 0
    MINOR = 1
    POSSIBLE = 2
    MAJOR = 3

    def __lt__(self, other):
        """Compare by ``value``; a non-member operand is compared to the raw value."""
        other_rank = other.value if other.__class__ is self.__class__ else other
        return self.value < other_rank


def parse_grading(text: str):
    """Collect every ``Category / Points awarded / Description`` entry in ``text``.

    Returns:
        dict: ``"points"`` holds the sum of all integer scores and ``"details"``
        holds one ``{"title", "points", "desc"}`` dict per entry, in order of
        appearance.

    Raises:
        ValueError: from ``int()`` when a captured score is not an integer literal.
    """
    entry_re = re.compile(r"""Category:\s*(.*?)(?:\s|\n)*Points\s+awarded:\s*(.*?)(?:\s|\n)*Description:\s*(.*?)(?:\s|\n)*(?=Category:|$)""", re.DOTALL | re.VERBOSE)

    details = [
        {"title": title, "points": int(score), "desc": description}
        for title, score, description in entry_re.findall(text)
    ]
    total = sum(entry["points"] for entry in details)

    return {"points": total, "details": details}

def find_last_boxed_content(text: str, list_answer: bool = False) -> Optional[str]:
    """Locate the final ``boxed{...}`` / ``fbox{...}`` group in ``text``.

    Parameters:
    text (str): The text to search.
    list_answer (bool): When True and more than one box exists, every box on the
        last line that contains a box is joined with "," and returned.

    Returns:
    tuple: ``(content, warning)``. ``content`` is None when no box is found.
        ``warning`` is ``WarningType.POSSIBLE`` when the joined-line path was
        taken, otherwise ``WarningType.NONE``.
    """
    # (?2) recurses into group 2 so nested braces inside the box are balanced.
    pattern = r"(boxed|fbox)\{((?:[^{}]|\{(?2)\})*)\}"
    boxes = list(regex.finditer(pattern, text))
    if not boxes:
        return None, WarningType.NONE

    if list_answer and len(boxes) > 1:
        # Walk the lines bottom-up and stop at the first one containing a box.
        for line in reversed(text.split("\n")):
            contents = [found.group(2) for found in regex.finditer(pattern, line)]
            if contents:
                return ",".join(contents), WarningType.POSSIBLE

    # Either a single box, or no individual line holds a complete box.
    return boxes[-1].group(2), WarningType.NONE


def extract_boxed_answer(text: str, list_answer: bool = False) -> Optional[str]:
    """Return ``(answer, warning)`` for the last boxed expression in ``text``.

    If the boxed content contains "=", only the part after the last "=" is kept.
    ``answer`` is None when no boxed content exists; ``warning`` is passed through
    unchanged from ``find_last_boxed_content``.
    """
    answer, warning = find_last_boxed_content(text, list_answer)
    if answer is None:
        return None, warning
    # rpartition yields the whole string as its last item when "=" is absent.
    return answer.rpartition("=")[2], warning


def extract_boxed_answer_parse(text: str, parse: bool = True, list_answer: bool = False) -> Optional[int]:
    """Extract the boxed answer from ``text`` and convert it to a value.

    Parameters:
    text (str): The text to search for a boxed answer.
    parse (bool): When the answer is not a plain integer literal, run it through
        ``parse_answer`` if True; otherwise return the raw string.
    list_answer (bool): Forwarded to ``extract_boxed_answer``.

    Returns:
    tuple: ``(answer, warning)``. ``answer`` is a ``sympy.Integer`` for plain
        integer literals, the ``parse_answer`` result or the raw string
        otherwise, and None (with ``WarningType.MAJOR``) when nothing is boxed.
    """
    answer, warning = extract_boxed_answer(text, list_answer)
    if answer is None:
        return None, WarningType.MAJOR

    try:
        return sympy.Integer(int(answer)), warning
    except (ValueError, TypeError):
        # Not an integer literal; fall through to the general parser. Only the
        # conversion errors are caught so KeyboardInterrupt/SystemExit propagate.
        pass

    if not parse:
        return answer, warning

    parsed_answer, parse_warning = parse_answer(answer)
    # Keep the more severe of the extraction and parsing warnings, so a warning
    # raised while locating the box is not lost.
    return parsed_answer, max(warning, parse_warning)

def extract_last_integer(text: str) -> Optional[int]:
    """Fallback extractor: return the last standalone run of digits in ``text``.

    Returns:
    tuple: ``(value, WarningType.MAJOR)``. The warning is always MAJOR. ``value``
        is None when no digit run exists or it cannot be converted to ``int``.
    """
    matches = list(regex.finditer(r"\b\d+\b", text))
    if not matches:
        return None, WarningType.MAJOR
    try:
        return int(matches[-1].group()), WarningType.MAJOR
    except ValueError:
        # int() can still reject a digit string, e.g. one longer than the
        # interpreter's integer string conversion limit.
        return None, WarningType.MAJOR


def extract_answer(text: str, strict_parsing: bool = True, parse: bool = True, list_answer: bool = False):
    """Extract the final answer from ``text`` and return ``(answer, warning)``.

    The text is first passed through ``replace_unicode``, then the boxed answer
    is extracted and parsed. With ``strict_parsing`` False and no boxed answer
    found, the result of ``extract_last_integer`` is returned instead.
    """
    text, unicode_warning = replace_unicode(text)
    answer, boxed_warning = extract_boxed_answer_parse(text, parse, list_answer)

    if strict_parsing or answer is not None:
        # Report the more severe of the two warnings.
        return answer, max(unicode_warning, boxed_warning)

    return extract_last_integer(text)

def parse_answer(s: str, primitive_type: type = None):
    """Parse an answer string into a value or an ``AnswerList``.

    Parameters:
    s (str): The raw answer string.
    primitive_type (type): Forwarded to ``ParseList.parse``.

    Returns:
    tuple: ``(output, warning)``. ``output`` is None when parsing fails (the
        warning is then at least MAJOR), the single element when exactly one
        was parsed, and an ``AnswerList`` when the result is a list or tuple.
    """
    manually_mapped = s in manual_mapper
    if manually_mapped:
        logger.warning(f"Applying manual parsing to {s}")
        s = manual_mapper[s]

    s = remove_outer_brackets(normalize_string(remove_invalid_characters(s)))

    # Wrap in parentheses so the string is always parsed as a list.
    output, parse_warning = ParseList.parse("(" + s + ")", primitive_type=primitive_type)
    base_warning = WarningType.MAJOR if manually_mapped else WarningType.NONE
    warning = max(base_warning, parse_warning)

    if output is None:
        logger.warning(f"Could not parse {s}, returning None")
        return None, max(warning, WarningType.MAJOR)

    if len(output) == 1:
        output = output[0]
    if isinstance(output, (list, tuple)):
        output = AnswerList(output)
    return output, warning

def normalize_string(s):
    """Strip LaTeX layout noise and unify all bracket kinds to parentheses.

    The replacements run in a fixed order: sizing commands are dropped, align
    environments are unwrapped, then "[" / "]" and "\\{" / "\\}" become "(" / ")",
    "$" is removed, and ``\\hline`` / ``\\vline`` are removed. The result is
    passed through ``strip``.
    """
    for sizing_command in (r"\left", r"\right", r"\Bigl", r"\Bigr", r"\bigl", r"\bigr"):
        s = s.replace(sizing_command, "")

    s = remove_aligns(s)

    substitutions = (
        ("[", "("),
        ("]", ")"),
        ("\\{", "("),  # sets will be converted to lists
        ("\\}", ")"),  # sets will be converted to lists
        ("$", ""),
        (r"\hline", ""),
        (r"\vline", ""),
    )
    for old, new in substitutions:
        s = s.replace(old, new)
    return strip(s)

def remove_outer_brackets(s):
    """
    Peel off enclosing parentheses for as long as one pair wraps the whole string.

    Parameters:
    s (str): The input string, possibly wrapped in one or more "(...)" pairs.

    Returns:
    str: The string without the outer pairs. A pair is removed only when the
    bracket depth never returns to zero before the final character.
    """
    while s and s[0] == "(" and s[-1] == ")":
        depth = 0
        wraps_everything = True
        # The final character is excluded: the depth reaching zero there is
        # exactly what a fully wrapping pair looks like.
        for char in s[:-1]:
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
            if depth == 0:
                # The opening bracket was closed early, e.g. "(a)(b)".
                wraps_everything = False
                break

        if not wraps_everything:
            break
        s = s[1:-1]

    return s

def remove_aligns(s: str) -> str:
    """Unwrap ``\\begin{align...} ... \\end{align...}`` environments.

    Each environment is replaced by its body with every "&" and every double
    backslash (row break) removed. Text outside the environments is untouched.
    """
    def unwrap(match):
        # Drop the column separators first, then the row breaks.
        body = match.group(1)
        return body.replace('&', '').replace("\\\\", "")

    # "align" followed by any non-"}" characters covers align*, aligned, alignat, ...
    # The body is matched non-greedily so each environment is handled separately.
    pattern = r'\\begin{align[^}]*}(.*?)\\end{align[^}]*}'
    return re.sub(pattern, unwrap, s, flags=re.DOTALL)

def replace_unicode(text: str) -> str:
    """Rewrite selected Unicode glyphs as LaTeX and return ``(text, warning)``.

    The first group of glyphs is rewritten into ``\\boxed{`` / ``}``. If any of
    them was present the warning is ``WarningType.POSSIBLE``, otherwise
    ``WarningType.NONE``. The second group (square root sign, multiplication
    sign, narrow no-break space) is rewritten without raising a warning.
    """
    original_text = text

    # Order matters: the "\n"-prefixed variants must be handled before the bare glyph.
    bracket_rewrites = (
        ("\u23a7", r"\boxed{"),
        ("\u23ab", r"}"),
        ("\n\u2502", r"\boxed{"),
        ("\u2502", r"}"),
        ("\n\u2503", r"\boxed{"),
        ("\u2503", r"}"),
        ("\n\uf8f0", r"\boxed{"),
        ("\uf8fb", r"}"),
    )
    for glyph, latex in bracket_rewrites:
        text = text.replace(glyph, latex)
    warning = WarningType.NONE if text == original_text else WarningType.POSSIBLE

    # these ones are for sure fine, no warning necessary
    safe_rewrites = (
        ("\u221a", r"\sqrt"),
        ("\u00d7", r"\cdot"),
        ("\u202f", r" "),
    )
    for glyph, latex in safe_rewrites:
        text = text.replace(glyph, latex)
    return text, warning

def remove_invalid_characters(text):
    """Delete the LaTeX spacing commands ``\\;`` and ``\\,`` from ``text``.

    The two removals run one after the other (``\\;`` first), so a ``\\,`` that
    only forms after a ``\\;`` has been removed is deleted as well.
    """
    for spacing_pattern in (r'\\;', r'\\,'):
        text = re.sub(spacing_pattern, '', text)
    return text

def strip(s: str):
    """Trim whitespace and leftover LaTeX line-break markers from both ends of ``s``.

    After ``str.strip()`` this removes, in order: leading literal ``\\n``
    (backslash + "n") pairs, trailing literal ``\\n`` pairs, leading
    backslash-space pairs, and finally three characters at a time while the
    string starts with two or more backslashes, an optional newline, and "(".
    """
    s = s.strip()

    # Whole two-character tokens are removed on purpose: str.strip(r"\n") would
    # strip "\" and "n" individually and could eat the "\" of "\begin".
    literal_newline = r"\n"
    while s[:2] == literal_newline:
        s = s[2:]
    while s[-2:] == literal_newline:
        s = s[:-2]
    while s[:2] == "\\ ":
        s = s[2:]

    # NOTE(review): exactly three characters are dropped per round, so for a
    # bare two-backslash prefix directly followed by "(" the bracket itself is
    # consumed too -- confirm this is intended.
    leading_break = re.compile(r"\\{2,}\n?\(")
    while leading_break.match(s):
        s = s[3:]
    return s

def check_answers(ans1, ans2):
    """Return whether two parsed answers are considered equal.

    None never equals anything, and a list-like answer never equals a scalar.
    When both answers expose a callable ``equals`` its result is returned.
    Otherwise strings are compared exactly and everything else must differ by
    less than ``10 ** -10``. Any exception raised while comparing is logged
    and treated as "not equal".
    """
    if ans1 is None or ans2 is None:
        return False

    list_types = (list, AnswerList)
    if (type(ans1) in list_types) != (type(ans2) in list_types):
        return False

    try:
        both_have_equals = all(
            hasattr(ans, 'equals') and callable(ans.equals) for ans in (ans1, ans2)
        )
        if both_have_equals:
            return ans1.equals(ans2)

        # do approximate equal here
        if isinstance(ans1, str) or isinstance(ans2, str):
            return ans1 == ans2
        return bool(abs(ans1 - ans2) < 10 ** -10)
    except Exception as e:
        logger.error(f'Could not establish equality for answers {ans1} and {ans2}, error: {e}')
        return False

class AnswerList:
    """Order-insensitive collection of answers.

    Entries whose string form contains no digit are dropped at construction
    time (with a logged warning).
    """

    def __init__(self, answers: list[Any]):
        """Store the digit-bearing entries of ``answers``.

        Raises:
            ValueError: if ``answers`` is neither a list nor a tuple.
        """
        if not isinstance(answers, (list, tuple)):
            raise ValueError(f"Expected passed answers to be list or tuple, received {type(answers)}")

        kept = []
        for answer in answers:
            if re.search(r'\d', str(answer)) is None:
                logger.warning(f'Could not find any numbers in {answer}, removed from list')
                continue
            kept.append(answer)

        self.answers = kept

    def equals(self, other: list[Any]):
        """Return True when ``other`` has the same length and the same answers in any order.

        Each element of ``other`` can be matched by at most one stored answer.
        Elements are compared with ``check_answers``.
        """
        if len(self.answers) != len(other):
            return False

        used_positions = set()
        for own_answer in self.answers:
            for position, candidate in enumerate(other):
                if position in used_positions:
                    continue
                if check_answers(own_answer, candidate):
                    used_positions.add(position)
                    break
            else:
                # No unused element of ``other`` matched this answer.
                return False
        return True

    def __str__(self):
        return f"[{','.join(map(str, self.answers))}]"

    def __len__(self):
        return len(self.answers)

    def __iter__(self):
        return iter(self.answers)

class ParseObject:
    """Base class for the parsers; subclasses override the predicates and ``parse``."""

    @classmethod
    def is_at_start(cls, string):
        """Whether ``string`` begins with this kind of object; the base class never matches."""
        return False

    @classmethod
    def is_complete(cls, string):
        """Whether ``string`` has as many closing as opening braces and parentheses."""
        bracket_pairs = (("{", "}"), ("(", ")"))
        return all(string.count(opener) == string.count(closer) for opener, closer in bracket_pairs)

    @classmethod
    def is_finished(cls, string):
        """Whether ``string`` ends where this kind of object may end; always True here."""
        return True

    @classmethod
    def parse(cls, string):
        """Parse ``string``; must be implemented by subclasses."""
        raise NotImplementedError
    
class ParsePrimitive(ParseObject):
    """Parser for a single scalar answer: an integer, a float, or an expression."""

    @classmethod
    def _rewrite_latex_once(cls, latex_str, unwrap_digit_braces):
        """Apply one round of the LaTeX-to-SymPy text rewrites and return the result.

        With ``unwrap_digit_braces`` True, brace groups that contain only digits
        are first turned into parenthesised groups.
        """
        if unwrap_digit_braces:
            latex_str = re.sub(r'\{(\d+)\}', r'(\1)', latex_str)
        latex_str = re.sub(r'\\*(?:dfrac|tfrac|frac)\{([^{}]*)\}\{([^{}]*)\}', r'(\1)/(\2)', latex_str)
        latex_str = re.sub(r'\\*binom\{([^{}]*)\}\{([^{}]*)\}', r'binomial(\1, \2)', latex_str)
        latex_str = re.sub(r'\\*sqrt\[(.*?)\]\{(.*?)\}', r'(\2)**(1/(\1))', latex_str)
        latex_str = re.sub(r'\\*sqrt\{(.*?)\}', r'(\1)**(1/2)', latex_str)

        latex_str = latex_str.replace('^', '**')
        latex_str = latex_str.replace('\\cdot', '*').replace('\\times', '*')
        latex_str = latex_str.replace('\\pi', 'pi').replace('\\e', 'E').replace('\\i', 'I')
        return latex_str

    @classmethod
    def parse(cls, string, primitive_type):
        """Parse ``string`` into a number or SymPy expression.

        Parameters:
        string (str): The text of a single answer.
        primitive_type (type): When this is ``Fraction``, integer-valued results
            from the integer and float paths are returned as ``Fraction(n, 1)``.

        Returns:
        tuple: ``(value, warning)``. ``value`` is None, with
            ``WarningType.MAJOR``, when every parsing strategy fails.
        """
        warning = WarningType.NONE
        # Integer
        if string.isdigit():
            if primitive_type == Fraction:
                # Return the same (value, warning) pair as every other path,
                # because callers unpack two values.
                return Fraction(int(string), 1), warning
            return int(string), warning
        # Float
        try:
            float_string = float(string)
            if int(float_string) == float_string:
                if primitive_type == Fraction:
                    return Fraction(int(float_string), 1), warning
                return int(float_string), warning
            return float_string, warning
        except (ValueError, OverflowError):
            # ValueError: not a float literal, or NaN passed to int().
            # OverflowError: int() of an infinite float (e.g. "inf", "1e999").
            # Both fall through to the expression parser.
            pass
        # Expression
        # Brace a bare numeric argument of sqrt so the rewrites below see it.
        string = re.sub(r'sqrt(\d+)', r'sqrt{\1}', string)
        try:
            latex_str = string
            # Two rewrite phases of at most five rounds each, stopping early once
            # a round changes nothing. Only the second phase unwraps brace groups
            # that contain nothing but digits.
            for unwrap_digit_braces in (False, True):
                for _ in range(5):
                    init_str = latex_str
                    latex_str = cls._rewrite_latex_once(latex_str, unwrap_digit_braces)
                    if init_str == latex_str:
                        break

            # Handle implicit multiplication
            latex_str = re.sub(r'(\d|(?<![a-zA-Z])[a-zA-Z]{1,2}(?![a-zA-Z]))\(', r'\1*(', latex_str)
            latex_str = re.sub(r'\)(\d|(?<![a-zA-Z])[a-zA-Z]{1,2}(?![a-zA-Z]))', r')*\1', latex_str)
            latex_str = re.sub(r'(?<=\d)((?<![a-zA-Z])[a-zA-Z]{1,2}(?![a-zA-Z]))', r'*\1', latex_str)
            latex_str = re.sub(r'((?<![a-zA-Z])[a-zA-Z]{1,2}(?![a-zA-Z]))(?=\d)', r'\1*', latex_str)
            # Remaining brace groups become bracketed, ", "-separated lists.
            latex_str = re.sub(r'\{([^{}]*)\}', lambda m: '[' + m.group(1).replace(',', ', ') + ']',
                               latex_str)

            # NOTE(review): the SymPy documentation warns that sympify uses eval;
            # if `string` can come from untrusted text this is a code-execution
            # risk -- confirm against callers.
            string = sympy.sympify(latex_str,
                                   locals={'binomial': sympy.binomial,
                                           'pi': sympy.pi,
                                           'E': sympy.E,
                                           'I': sympy.I}
                                   )
        except Exception:
            # The text rewrites did not produce something sympify accepts; fall
            # back to latex2sympy_fixed and evaluate numerically.
            try:
                string_no_eq = string
                if "=" in string_no_eq:
                    # rfind is used to keep only what follows the last occurrence of "="
                    string_no_eq = string_no_eq[string_no_eq.rfind("=")+1:]
                float_val = float(N(latex2sympy_fixed(string_no_eq), 101))
                if float_val.is_integer() or float("inf") == float_val or float("-inf") == float_val:
                    return int(N(latex2sympy_fixed(string_no_eq), 50001)), warning # important for large ints
                return float_val, warning
            except Exception as e:
                logger.warning(f"Error: Custom parsing error {e}, {string_no_eq}")
                warning = max(warning, WarningType.MAJOR)
                return None, warning

        return string, warning

    @classmethod
    def is_at_start(cls, string):
        """A primitive may start anywhere, so this always returns True."""
        return True

class ParseList(ParseObject):
    """Parser for a parenthesised, delimiter-separated list of answers (possibly nested)."""

    @classmethod
    def is_at_start(cls, string):
        """A list starts with an opening parenthesis."""
        return string.startswith(r"(")

    @classmethod
    def is_finished(cls, string):
        """Safe condition for finishing a list: it ends with ")" once outer whitespace and commas are ignored."""
        return string.strip().strip(",").endswith(")")

    @classmethod
    def is_complete(cls, string):
        """Whether ``string`` has as many ")" as "(" characters."""
        return string.count("(") == string.count(")")

    @classmethod
    def never_zero_count(cls, string):
        """Return True if the running ``count("(") - count(")")`` is non-zero after every character.

        An empty string yields True because no character is ever examined.
        """
        count = 0
        for char in string:
            if char == "(":
                count += 1
            if char == ")":
                count -= 1
            if count == 0:
                return False
        return True

    @classmethod
    def parse(cls, string, delimiter=(r"\n", ","), primitive_type=None, depth=0):
        """Parse a parenthesised string into a list of parsed elements.

        Parameters:
        string (str): The text to parse; it must start with "(".
        delimiter (str | sequence of str): Candidate separators in priority
            order. The first one that occurs in the unwrapped string is used,
            falling back to the first candidate. The default ``r"\\n"`` is the
            two-character backslash-n sequence, not a newline.
        primitive_type (type): Forwarded to the element parsers.
        depth (int): Current nesting depth. Beyond 50 nested lists are no
            longer attempted and only primitives are parsed.

        Returns:
        tuple: ``(elements, warning)`` where ``elements`` is a list, or
            ``(None, WarningType.MAJOR)`` when the string cannot be parsed.

        Raises:
        ValueError: if no progress can be made while ``depth`` exceeds 50.
        """
        if isinstance(delimiter, str):
            delimiter = [delimiter]
        if not string.startswith("("):
            # Same (value, warning) shape as every other exit, because callers
            # unpack two values.
            return None, WarningType.MAJOR

        string = string.strip().strip(",")
        # Drop the outer parentheses only when they enclose the whole string.
        if cls.never_zero_count(string[:-1]):
            string = string[1:-1]
        string = strip(string)

        used_delim = delimiter[0]
        for delim in delimiter:
            if delim in string:
                used_delim = delim
                break

        allowed_objects = [ParsePrimitive] if depth > 50 else [ParseList, ParsePrimitive]
        output = []
        warning = WarningType.NONE
        while len(string) > 0:
            previous_string = string
            pieces = string.split(used_delim)
            for obj in allowed_objects:
                if not obj.is_at_start(strip(string)):
                    continue

                # Grow the candidate one delimiter-separated piece at a time
                # until it is balanced and ends where this object may end.
                current_index = 1
                candidate = strip(used_delim.join(pieces[:current_index]))
                while not obj.is_complete(candidate) or not obj.is_finished(candidate):
                    current_index += 1
                    candidate = strip(used_delim.join(pieces[:current_index]))
                    if current_index >= len(pieces):
                        break
                if not obj.is_complete(candidate) or not obj.is_finished(candidate):
                    # Even the whole remainder does not form this kind of object.
                    continue

                if obj == ParseList:
                    parsed, new_warning = obj.parse(candidate, primitive_type=primitive_type,
                                                    depth=depth + 1)
                else:
                    parsed, new_warning = obj.parse(candidate, primitive_type=primitive_type)
                warning = max(warning, new_warning)
                output.append(parsed)
                string = strip(used_delim.join(pieces[current_index:]))
                break

            if previous_string == string:
                # Nothing was consumed in this round, so parsing cannot progress.
                if depth > 50:
                    logger.error(f"Response {string} reached depth > 50")
                    raise ValueError(f"Failed to parse '{string}'")
                return None, WarningType.MAJOR
        return output, warning