| |
|
|
| from math import isclose |
|
|
| import regex |
| from sympy import N, simplify |
| from sympy.parsing.latex import parse_latex |
| from sympy.parsing.sympy_parser import parse_expr |
|
|
|
|
def parse_digits(num):
    """Parse ``num`` as a float, tolerating thousands separators and percents.

    ``num`` is first converted with ``str`` and every comma is removed
    (``"1,000"`` -> ``1000.0``). If the result is not a float literal but ends
    in ``%`` (optionally LaTeX-escaped as ``\\%``), the percent sign is
    stripped and the value is divided by 100 (``"50%"`` -> ``0.5``).

    Returns:
        float | None: the parsed value, or None if ``num`` is not numeric.
    """
    text = str(num).replace(",", "")
    try:
        return float(text)
    except ValueError:
        # Not a plain number; fall through to the percentage handling below.
        pass
    if text.endswith("%"):
        text = text[:-1]
        # LaTeX answers usually escape the percent sign as "\%".
        if text.endswith("\\"):
            text = text[:-1]
        try:
            return float(text) / 100
        except ValueError:
            pass
    return None
|
|
|
|
def is_digit(num):
    """Report whether ``num`` is numeric according to ``parse_digits``.

    Returns:
        bool: True when ``parse_digits(num)`` produces a value, False when it
        yields None.
    """
    parsed = parse_digits(num)
    return parsed is not None
|
|
|
|
def symbolic_equal(a, b):
    """Return True if ``a`` and ``b`` are symbolically or numerically equal.

    Each argument is parsed first with ``parse_latex`` and then with
    ``parse_expr``; if both parsers fail, the raw value is kept unchanged.
    The values are equal when ``simplify(a - b) == 0``. Failing that, they are
    equal when ``N(a)`` and ``N(b)`` agree within an absolute tolerance of
    1e-3. An error raised by either strategy counts as "not equal" for that
    strategy.

    Returns:
        bool
    """

    def _parse(s):
        # Try LaTeX syntax first, then plain SymPy syntax; keep the raw input
        # if neither parser accepts it.
        # NOTE(review): sympy's parse_expr evaluates its input with eval();
        # only feed it trusted strings (sandbox model outputs if necessary).
        for parser in (parse_latex, parse_expr):
            try:
                return parser(s)
            except Exception:
                continue
        return s

    a = _parse(a)
    b = _parse(b)

    # Strategy 1: exact symbolic equality. An error here (e.g. when an
    # argument is still an unparsed string) just means "try strategy 2".
    try:
        if simplify(a - b) == 0:
            return True
    except Exception:
        pass

    # Strategy 2: numeric evaluation within an absolute tolerance. An error
    # here (e.g. a value isclose cannot treat as a real number) means
    # "not equal".
    try:
        return isclose(N(a), N(b), abs_tol=1e-3)
    except Exception:
        return False
|
|
|
|
def _is_bracketed(text):
    """Return True if ``text`` is fully enclosed by ``(``/``[`` ... ``)``/``]``.

    The pattern is anchored at BOTH ends (``fullmatch``). A prefix-only
    ``match`` also accepts strings like ``"(a+b)c"``; the caller's ``[1:-1]``
    slice would then drop the trailing ``c``, making ``"(a+b)c"`` and
    ``"(a+b)d"`` compare equal.
    """
    return regex.fullmatch(r"(\(|\[).+(\)|\])", text) is not None


def _parts_equal(pred_parts, ref_parts, include_percentage, is_close):
    """Return True if both part lists have equal length and match pairwise via ``math_equal``."""
    return len(pred_parts) == len(ref_parts) and all(
        math_equal(pred_part, ref_part, include_percentage, is_close)
        for pred_part, ref_part in zip(pred_parts, ref_parts)
    )


def _is_latex_matrix(text):
    """Return True if ``text`` is wrapped in a LaTeX pmatrix/bmatrix environment."""
    return text.startswith(("\\begin{pmatrix}", "\\begin{bmatrix}")) and text.endswith(
        ("\\end{pmatrix}", "\\end{bmatrix}")
    )


def _latex_matrix_rows(text):
    """Return the stripped, non-empty rows of a LaTeX pmatrix/bmatrix string."""
    # The pmatrix and bmatrix delimiters have identical lengths, so a single
    # slice strips either environment.
    body = text[len("\\begin{pmatrix}") : -len("\\end{pmatrix}")]
    return [row.strip() for row in body.split("\\\\") if row.strip()]


def _equation_as_difference(equation):
    """Rewrite a single-``=`` equation ``"lhs = rhs"`` as ``"lhs - (rhs)"``."""
    lhs, rhs = equation.split("=")
    return f"{lhs.strip()} - ({rhs.strip()})"


def math_equal(prediction, reference, include_percentage=True, is_close=True):
    """Return True if ``prediction`` and ``reference`` denote the same math answer.

    Checks are tried in order, and the function returns as soon as one succeeds:

    1. Identical ``str`` forms.
    2. Numerical equality. If both values parse with ``parse_digits``, the
       prediction is compared against the reference. When
       ``include_percentage`` is set, it is also compared against
       ``reference / 100`` and ``reference * 100``. This check is final:
       its result is returned without trying the later checks.
    3. An empty prediction (falsy, other than ``0``/``False``) never matches.
    4. Tuples/intervals wrapped in ``()``/``[]``. Comma-separated components
       are compared pairwise and recursively. Bracket kinds are not
       distinguished, so ``"(1,2)"`` matches ``"[1,2]"``.
    5. LaTeX ``pmatrix``/``bmatrix`` matrices, compared cell by cell and
       recursively.
    6. Equations. Two single-``=`` equations match if their ``lhs - (rhs)``
       forms are symbolically equal up to sign. An equation whose left-hand
       side is at most two characters (e.g. ``"x = 5"``) is compared by its
       right-hand side against an answer that contains no ``=``.
    7. Symbolic equality of the whole strings via ``symbolic_equal``.

    Args:
        prediction: Candidate answer (any value; compared via ``str``).
        reference: Ground-truth answer.
        include_percentage: Also accept the reference scaled by 1/100 or 100.
        is_close: Compare numbers with ``isclose(abs_tol=1e-3)`` instead of
            exact ``==``.

    Returns:
        bool
    """
    if str(prediction) == str(reference):
        return True

    # Numerical equality. parse_digits returns a float or None, so the
    # arithmetic and comparisons below need no exception handling.
    pred_value = parse_digits(prediction)
    ref_value = parse_digits(reference)
    if pred_value is not None and ref_value is not None:
        if include_percentage:
            # Tolerate answers given as a percentage vs. a fraction (or vice versa).
            candidates = [ref_value / 100, ref_value, ref_value * 100]
        else:
            candidates = [ref_value]
        if is_close:
            return any(isclose(c, pred_value, abs_tol=1e-3) for c in candidates)
        return any(c == pred_value for c in candidates)

    # An empty answer ("", None, [], ...) never matches; 0 and False are
    # legitimate answers and are allowed through.
    if not prediction and prediction not in (0, False):
        return False

    reference = str(reference).strip()
    prediction = str(prediction).strip()

    # Tuples / intervals: compare comma-separated components pairwise.
    if _is_bracketed(prediction) and _is_bracketed(reference):
        if _parts_equal(
            prediction[1:-1].split(","),
            reference[1:-1].split(","),
            include_percentage,
            is_close,
        ):
            return True

    # LaTeX matrices: same number of rows, and each row matches cell by cell.
    if _is_latex_matrix(prediction) and _is_latex_matrix(reference):
        pred_rows = _latex_matrix_rows(prediction)
        ref_rows = _latex_matrix_rows(reference)
        if len(pred_rows) == len(ref_rows) and all(
            _parts_equal(
                pred_row.split("&"), ref_row.split("&"), include_percentage, is_close
            )
            for pred_row, ref_row in zip(pred_rows, ref_rows)
        ):
            return True

    # Equations.
    pred_eq_count = prediction.count("=")
    ref_eq_count = reference.count("=")
    if pred_eq_count == 1 and ref_eq_count == 1:
        pred_expr = _equation_as_difference(prediction)
        ref_expr = _equation_as_difference(reference)
        # "a = b" and "b = a" differ only in sign once moved to one side.
        if symbolic_equal(pred_expr, ref_expr) or symbolic_equal(
            f"-({pred_expr})", ref_expr
        ):
            return True
    elif (
        pred_eq_count == 1
        and len(prediction.split("=")[0].strip()) <= 2
        and ref_eq_count == 0
    ):
        # e.g. prediction "x = 5" vs reference "5": compare the right-hand side.
        if math_equal(
            prediction.split("=")[1], reference, include_percentage, is_close
        ):
            return True
    elif (
        ref_eq_count == 1
        and len(reference.split("=")[0].strip()) <= 2
        and pred_eq_count == 0
    ):
        if math_equal(
            prediction, reference.split("=")[1], include_percentage, is_close
        ):
            return True

    # Last resort: full symbolic comparison of the two strings.
    return symbolic_equal(prediction, reference)
|
|