Upload math_verify/grader.py with huggingface_hub
Browse files- math_verify/grader.py +877 -0
math_verify/grader.py
ADDED
|
@@ -0,0 +1,877 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MIT License
|
| 2 |
+
|
| 3 |
+
# Copyright (c) 2024 The HuggingFace Team
|
| 4 |
+
|
| 5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
# in the Software without restriction, including without limitation the rights
|
| 8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
# furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
# copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
# SOFTWARE.
|
| 22 |
+
|
| 23 |
+
# Heavily inspired by https://github.com/QwenLM/Qwen2.5-Math and https://github.com/huggingface/lm-evaluation-harness
|
| 24 |
+
import logging
|
| 25 |
+
import re
|
| 26 |
+
from itertools import product
|
| 27 |
+
|
| 28 |
+
from latex2sympy2_extended import is_expr_of_only_symbols
|
| 29 |
+
from latex2sympy2_extended.logic import And
|
| 30 |
+
from latex2sympy2_extended.sets import FiniteSet
|
| 31 |
+
from sympy import (
|
| 32 |
+
Basic,
|
| 33 |
+
E,
|
| 34 |
+
Eq,
|
| 35 |
+
Float,
|
| 36 |
+
GreaterThan,
|
| 37 |
+
Interval,
|
| 38 |
+
LessThan,
|
| 39 |
+
MatrixBase,
|
| 40 |
+
MatrixExpr,
|
| 41 |
+
Mul,
|
| 42 |
+
Number,
|
| 43 |
+
Rational,
|
| 44 |
+
Set,
|
| 45 |
+
StrictGreaterThan,
|
| 46 |
+
StrictLessThan,
|
| 47 |
+
Symbol,
|
| 48 |
+
Tuple,
|
| 49 |
+
default_sort_key,
|
| 50 |
+
nan,
|
| 51 |
+
ordered,
|
| 52 |
+
simplify,
|
| 53 |
+
solve,
|
| 54 |
+
zoo,
|
| 55 |
+
)
|
| 56 |
+
from sympy import FiniteSet as SympyFiniteSet
|
| 57 |
+
from sympy.core.function import UndefinedFunction
|
| 58 |
+
from sympy.core.relational import Relational
|
| 59 |
+
|
| 60 |
+
from math_verify.errors import TimeoutException
|
| 61 |
+
from math_verify.utils import timeout
|
| 62 |
+
|
| 63 |
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# NOTE(review): not referenced in the visible part of this file; presumably a
# one-shot guard so a timeout-related warning is only emitted once -- confirm
# against the code that reads/writes it.
TIMEOUT_WARNING_SHOWN = False


# Maps a relational class to the class that expresses the same relation with
# its two sides swapped (a <= b  <=>  b >= a); Eq maps to itself.
# sympy_compare_relational looks the gold relation's type up here to detect
# predictions that are written "the other way round".
INVERSE_RELATIONS = {
    GreaterThan: LessThan,
    LessThan: GreaterThan,
    StrictGreaterThan: StrictLessThan,
    StrictLessThan: StrictGreaterThan,
    Eq: Eq,
}
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def safe_sympy_doit(a: Basic | MatrixBase):
    """Evaluate a sympy expression with ``doit()``, falling back to the input on failure.

    ``doit()`` walks the expression tree and evaluates the nodes it finds, so
    ``1 + 1 + 1`` becomes ``3``. It can evaluate more than one would like
    (integrals are evaluated as well). Expressions coming from
    latex2sympy2_extended are built lazily, so evaluation only happens when it
    is explicitly requested, as it is here.

    Args:
        a: A sympy Basic or MatrixBase expression to evaluate.

    Returns:
        ``a.doit()`` when the call succeeds; the untouched ``a`` if it raised.
    """
    try:
        evaluated = a.doit()
    except Exception:
        # Best effort: an expression that cannot be evaluated is returned as-is.
        return a
    return evaluated
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def is_atomic_or_pct_atomic(expr: Basic | MatrixBase, atomic_type: type) -> bool:
    """Tell whether ``expr`` is an ``atomic_type`` or a percentage of one.

    Args:
        expr: The sympy expression to inspect.
        atomic_type: The type the expression (or its percentage base) must have.

    Returns:
        True if ``expr`` is an instance of ``atomic_type`` or has the shape
        ``<atomic_type> * Rational(1, 100)``; False otherwise.
    """
    if isinstance(expr, atomic_type):
        return True

    # latex2sympy_extended turns "X%" into X * Rational(1, 100), so a
    # percentage is recognised as a two-factor product with 1/100 second.
    if not isinstance(expr, Mul) or len(expr.args) != 2:
        return False

    base, factor = expr.args
    return factor == Rational(1, 100) and isinstance(base, atomic_type)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def sympy_numeric_eq(
    a: Basic | MatrixBase,
    b: Basic | MatrixBase,
    float_rounding: int,
    numeric_precision: int,
):
    """Compare two sympy expressions numerically.

    Rounding is only applied when one side is a float (or a percentage of a
    float), so that e.g. ``1/3`` matches ``0.333333``; exact numbers are
    compared exactly to avoid false positives.

    Args:
        a: First sympy expression.
        b: Second sympy expression.
        float_rounding: Number of decimals both sides are rounded to when a
            float is involved.
        numeric_precision: Value passed as ``n`` to ``evalf`` when the
            difference of two non-number expressions is evaluated.

    Returns:
        True if the expressions are numerically equal, False otherwise
        (including when the comparison itself raised).
    """
    matrix_like = (MatrixBase, MatrixExpr)
    if isinstance(a, matrix_like) and isinstance(b, matrix_like):
        a, b = safe_sympy_doit(a), safe_sympy_doit(b)

    # Concrete matrices of equal shape are compared entry by entry.
    if isinstance(a, MatrixBase) and isinstance(b, MatrixBase) and a.shape == b.shape:
        for a_entry, b_entry in zip(a.flat(), b.flat(), strict=True):
            if not sympy_numeric_eq(a_entry, b_entry, float_rounding, numeric_precision):
                return False
        return True

    # Percentages count as numbers too, so 0.333333% can match 0.33333333333%.
    if is_atomic_or_pct_atomic(a, Number) or is_atomic_or_pct_atomic(b, Number):
        float_involved = is_atomic_or_pct_atomic(a, Float) or is_atomic_or_pct_atomic(
            b, Float
        )
        if not float_involved:
            # No float on either side: require exact equality after evaluation.
            return safe_sympy_doit(a) == safe_sympy_doit(b)

        a, b = safe_sympy_doit(a), safe_sympy_doit(b)
        try:
            return a.round(float_rounding) == b.round(float_rounding)
        except Exception:
            return False

    # Neither side is a plain number: evaluate the difference numerically.
    try:
        return (a - b).evalf(chop=True, n=numeric_precision) == 0  # type: ignore
    except Exception:
        return False
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def sympy_symbolic_eq(a: Basic | MatrixBase, b: Basic | MatrixBase) -> bool:
    """Check symbolic equality by simplifying ``a - b`` and testing for zero.

    Args:
        a: First sympy expression.
        b: Second sympy expression.

    Returns:
        True if the simplified difference is a zero matrix or a zero
        expression; False otherwise, including when any step raised.
    """
    try:
        difference = simplify(a - b)  # type: ignore
        if isinstance(difference, MatrixBase) and difference.is_zero_matrix:
            return True
        if isinstance(difference, Basic) and difference.is_zero:
            return True
    except Exception:
        # Subtraction/simplification is not defined for every pair of inputs.
        pass

    return False
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def unwrap_eq(s):
    """Return the right-hand side of an assignment relation, or ``s`` itself otherwise.

    For a chained assignment the right-hand side of the last relation is used.
    """
    return take_last_relation(s).rhs if is_assignment_relation(s) else s
|
| 200 |
+
|
| 201 |
+
def sort_key(x):
    """Build a sort key for ``x``, preferring its numerically evaluated form.

    Assignments are first reduced to their right-hand side. If evaluating with
    ``evalf()`` (or keying the evaluated value) fails, the key is computed
    from the unevaluated value instead.
    """
    try:
        numeric_form = unwrap_eq(x).evalf()
        return default_sort_key(numeric_form)
    except Exception:
        symbolic_form = unwrap_eq(x)
        return default_sort_key(symbolic_form)
|
| 206 |
+
|
| 207 |
+
def sympy_deep_compare_set_and_tuple(
    gold: SympyFiniteSet | Tuple,
    pred: SympyFiniteSet | Tuple,
    float_rounding: int,
    numeric_precision: int,
) -> bool:
    """Compare two finite sets / tuples element by element.

    This is what makes ``{1/3}`` match ``{0.333333}``: elements are paired
    positionally and each pair goes through ``sympy_expr_eq``.

    Args:
        gold: Expected finite set or tuple.
        pred: Predicted finite set or tuple.
        float_rounding: Forwarded to ``sympy_expr_eq``.
        numeric_precision: Forwarded to ``sympy_expr_eq``.

    Returns:
        True if both have the same length and every positional pair compares
        equal, False otherwise.

    Note:
        Set elements are sorted with ``sort_key`` and then paired by position;
        no all-pairs (cartesian product) matching is attempted, so the result
        relies on both sides sorting into a compatible order.
    """
    if len(gold) != len(pred):
        return False

    def _sorted_members(collection):
        return list(ordered(collection.args, keys=sort_key, default=False))

    if isinstance(gold, SympyFiniteSet):
        # Gold is a set: order is irrelevant, so both sides are sorted.
        gold_items = _sorted_members(gold)
        pred_items = _sorted_members(pred)
    elif isinstance(gold, Tuple) and isinstance(pred, FiniteSet):
        # Gold is a tuple: read the predicted set in the order it was written.
        pred_items = pred._unsorted_args
        gold_items = gold.args
    elif isinstance(pred, SympyFiniteSet):
        pred_items = _sorted_members(pred)
        gold_items = gold.args
    else:
        gold_items = gold.args
        pred_items = pred.args

    for gold_item, pred_item in zip(gold_items, pred_items, strict=True):
        if not sympy_expr_eq(gold_item, pred_item, float_rounding, numeric_precision):
            return False
    return True
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def sympy_compare_interval(
    a: Interval, b: Interval, float_rounding: int, numeric_precision: int
) -> bool:
    """Compare two intervals by openness and endpoints.

    Args:
        a: First interval.
        b: Second interval.
        float_rounding: Forwarded to ``sympy_expr_eq`` for the endpoints.
        numeric_precision: Forwarded to ``sympy_expr_eq`` for the endpoints.

    Returns:
        True if both intervals are open/closed on the same sides and their
        start and end points compare equal, False otherwise.
    """
    same_openness = a.left_open == b.left_open and a.right_open == b.right_open
    if not same_openness:
        return False

    if not sympy_expr_eq(a.start, b.start, float_rounding, numeric_precision):
        return False

    return sympy_expr_eq(a.end, b.end, float_rounding, numeric_precision)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def sympy_solve_and_compare(
    gold: Relational, pred: Relational, float_rounding: int, numeric_precision: int
) -> bool:
    """Solve both relations for their free symbols and compare the solutions.

    Args:
        gold: Expected relation.
        pred: Predicted relation.
        float_rounding: Forwarded to ``sympy_expr_eq``.
        numeric_precision: Forwarded to ``sympy_expr_eq``.

    Returns:
        For two equalities: True if the solutions, paired in sorted order,
        bind the same symbols to equal values. Otherwise the result of
        ``sympy_expr_eq`` on the two solution lists.

    Raises:
        Whatever ``solve`` raises, and ``ValueError`` from ``zip(strict=True)``
        when the two sides yield a different number of solutions or bindings;
        callers are expected to guard this call.
    """
    gold_solutions = list(ordered(solve(gold, gold.free_symbols)))
    pred_solutions = list(ordered(solve(pred, pred.free_symbols)))

    if not (isinstance(gold, Eq) and isinstance(pred, Eq)):
        return sympy_expr_eq(
            gold_solutions, pred_solutions, float_rounding, numeric_precision
        )

    # Equalities are expected to solve to a list of {symbol: value} dicts.
    paired_solutions = zip(
        ordered(gold_solutions, keys=sort_key, default=False),
        ordered(pred_solutions, keys=sort_key, default=False),
        strict=True,
    )
    for gold_solution, pred_solution in paired_solutions:
        paired_bindings = zip(
            sorted(gold_solution.items()), sorted(pred_solution.items()), strict=True
        )
        for (gold_symbol, gold_value), (pred_symbol, pred_value) in paired_bindings:
            if not (
                gold_symbol == pred_symbol
                and sympy_expr_eq(
                    gold_value, pred_value, float_rounding, numeric_precision
                )
            ):
                return False
    return True
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def sympy_compare_relational(
    gold: Relational | And,
    pred: Relational | And,
    float_rounding: int,
    numeric_precision: int,
) -> bool:
    """Compare two relational expressions (or two ``And`` chains of them).

    Strategies are tried in order: same relation type with equal
    ``lhs - rhs``; inverse relation type with the sides swapped
    (``a <= b`` vs ``b >= a``); finally solving both relations and comparing
    the solutions.

    Args:
        gold: Expected relation or ``And`` of relations.
        pred: Predicted relation or ``And`` of relations.
        float_rounding: Forwarded to ``sympy_expr_eq``.
        numeric_precision: Forwarded to ``sympy_expr_eq``.

    Returns:
        True if the relations are found equivalent by any strategy, False
        otherwise. Mixed inputs (one ``And``, one plain relation, or a
        non-relation) and ``And`` chains of different length yield False.
    """

    if isinstance(gold, And) and isinstance(pred, And):
        gold_parts = gold._unsorted_args
        pred_parts = pred._unsorted_args
        # Chains of different length cannot match. Checking up front avoids the
        # ValueError zip(strict=True) would raise after the common prefix matched.
        if len(gold_parts) != len(pred_parts):
            return False
        return all(
            sympy_compare_relational(g, p, float_rounding, numeric_precision)
            for g, p in zip(gold_parts, pred_parts, strict=True)
        )

    elif not isinstance(gold, Relational) or not isinstance(pred, Relational):
        return False

    # Helper to check if expressions are equivalent when flipped
    def are_flipped_inequalities_equal(a: Relational, b: Relational) -> bool:
        try:
            return sympy_expr_eq(
                a.lhs - a.rhs, b.rhs - b.lhs, float_rounding, numeric_precision
            )  # type: ignore
        except Exception:
            pass
        return False

    # Same type of relation (e.g. both <= or both >=)
    try:
        if type(gold) is type(pred) and sympy_expr_eq(
            gold.lhs - gold.rhs, pred.lhs - pred.rhs, float_rounding, numeric_precision
        ):  # type: ignore
            return True
    except Exception:
        pass

    # Check flipped inequalities (a <= b equals b >= a).
    # .get() instead of indexing: relation types without a registered inverse
    # (e.g. Ne) skip this strategy instead of raising KeyError.
    if INVERSE_RELATIONS.get(type(gold)) is type(pred) and are_flipped_inequalities_equal(  # type: ignore
        gold, pred
    ):
        return True

    # Last resort: solve both sides and compare the solution sets.
    try:
        if sympy_solve_and_compare(gold, pred, float_rounding, numeric_precision):
            return True
    except Exception:
        pass

    return False
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def sympy_str_eq(a: Basic | MatrixBase, b: Basic | MatrixBase) -> bool:
    """Cheap first-pass comparison of two sympy expressions using ``==``.

    Args:
        a: First sympy expression.
        b: Second sympy expression.

    Returns:
        The result of ``a == b``, or False if that comparison raised.

    Raises:
        ValueError: If ``a`` is ``nan`` or ``zoo``, which cannot be evaluated.
    """
    # We can't evaluate nan or zoo
    if a == nan or a == zoo:
        raise ValueError("Can't evaluate nan or zoo")

    try:
        return a == b
    except Exception:
        return False
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
def sympy_compare_sets(
    gold: Set | Basic | MatrixBase | Tuple,
    pred: Set | Basic | MatrixBase | Tuple,
    float_rounding: int,
    numeric_precision: int,
) -> bool:
    """Compare two set-like values using several strategies.

    Args:
        gold: Expected value; wrapped in a singleton set if it is neither a
            Set nor a Tuple.
        pred: Predicted value; wrapped the same way.
        float_rounding: Forwarded to the element/endpoint comparisons.
        numeric_precision: Forwarded to the element/endpoint comparisons.

    Returns:
        True if the two values are equal by any comparison strategy, False
        otherwise.
    """
    finite_like = (SympyFiniteSet, Tuple)

    # Promote bare expressions to singleton sets so everything is set-like.
    gold_set = gold if isinstance(gold, (Set, Tuple)) else SympyFiniteSet(gold)
    pred_set = pred if isinstance(pred, (Set, Tuple)) else SympyFiniteSet(pred)

    # Two intervals: compare openness and endpoints.
    if isinstance(gold_set, Interval) and isinstance(pred_set, Interval):
        return sympy_compare_interval(
            gold_set, pred_set, float_rounding, numeric_precision
        )

    # Direct equality.
    if gold_set == pred_set:
        return True

    # Two proper sets are equal when their symmetric difference is empty.
    try:
        both_are_sets = isinstance(gold_set, Set) and isinstance(pred_set, Set)
        if both_are_sets and gold_set.symmetric_difference(pred_set).is_empty:
            return True
    except Exception:
        pass

    # Finite sets / tuples: compare element by element.
    if isinstance(gold_set, finite_like) and isinstance(pred_set, finite_like):
        return sympy_deep_compare_set_and_tuple(
            gold_set, pred_set, float_rounding, numeric_precision
        )

    # "(1,2)" is parsed as an open interval while "1,2" is parsed as a set, so
    # an open interval is also matched against a two-element set/tuple by
    # comparing its endpoints with the elements.
    if (
        isinstance(gold_set, Interval)
        and isinstance(pred_set, finite_like)
        and gold_set.is_open
        and len(pred_set) == 2
    ):
        return sympy_deep_compare_set_and_tuple(
            Tuple(gold_set.start, gold_set.end),
            pred_set,
            float_rounding,
            numeric_precision,
        )

    if (
        isinstance(pred_set, Interval)
        and isinstance(gold_set, finite_like)
        and pred_set.is_open
        and len(gold_set) == 2
    ):
        return sympy_deep_compare_set_and_tuple(
            gold_set,
            Tuple(pred_set.start, pred_set.end),
            float_rounding,
            numeric_precision,
        )

    return False
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def sympy_compare_symbols(gold: Basic | MatrixBase, pred: Basic | MatrixBase) -> bool:
    """Compare two sympy expressions where at least one is a Symbol.

    Special cases handled:
    - A symbol named "e"/"E" against the constant ``E``.
    - A single symbol against a product of symbols (and ``E``), compared by
      concatenating the factor names case-insensitively.
    - Two symbols, compared by name; multi-character names ignore case.

    Args:
        gold: Expected expression.
        pred: Predicted expression.

    Returns:
        True if the expressions match by one of the rules above, or if their
        string forms are equal; False otherwise.
    """

    def _is_symbol_e_against_constant(candidate, other) -> bool:
        return (
            isinstance(candidate, Symbol)
            and candidate.name.lower() == "e"
            and other == E
        )

    def _is_product_of_symbols(expr) -> bool:
        return isinstance(expr, Mul) and all(
            arg == E or isinstance(arg, (Symbol)) for arg in expr.args
        )

    def _spell_out(expr) -> str:
        # E is always parsed as exp, so it contributes the letter "e".
        return "".join(
            arg.name if isinstance(arg, Symbol) else "e" for arg in expr.args
        )

    def _fold_long_name(name: str) -> str:
        return name.lower() if len(name) > 1 else name

    # Symbol "e" vs the constant E (a limitation of the parsed expressions).
    if _is_symbol_e_against_constant(gold, pred) or _is_symbol_e_against_constant(
        pred, gold
    ):
        return True

    # Single symbol vs product of symbols: $abc$ may be parsed as a*b*c.
    if isinstance(gold, Symbol) and _is_product_of_symbols(pred):
        return gold.name.lower() == _spell_out(pred).lower()

    if isinstance(pred, Symbol) and _is_product_of_symbols(gold):
        return pred.name.lower() == _spell_out(gold).lower()

    # Two plain symbols.
    if isinstance(gold, Symbol) and isinstance(pred, Symbol):
        return _fold_long_name(gold.name) == _fold_long_name(pred.name)

    return str(gold) == str(pred)
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
def is_relation(expr: Basic | MatrixBase) -> bool:
    """Tell whether ``expr`` is a relation or a non-empty ``And`` of relations.

    Args:
        expr: The expression to check.

    Returns:
        bool: True for a Relational, or for an And whose (unsorted) arguments
        are all Relational and number at least one; False otherwise.
    """
    if isinstance(expr, Relational):
        return True

    if not isinstance(expr, And):
        return False

    members = expr._unsorted_args
    return len(members) > 0 and all(isinstance(arg, Relational) for arg in members)
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
def is_equation(expr: Basic | MatrixBase) -> bool:
    """Tell whether ``expr`` is an equation or a non-empty ``And`` of equations.

    Args:
        expr: The expression to check.

    Returns:
        bool: True for an Eq, or for an And whose (unsorted) arguments are all
        Eq and number at least one; False otherwise.
    """
    if isinstance(expr, Eq):
        return True

    if not isinstance(expr, And):
        return False

    members = expr._unsorted_args
    return len(members) > 0 and all(isinstance(arg, Eq) for arg in members)
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
def is_assignment_relation(expr: Basic | MatrixBase) -> bool:
    """Tell whether ``expr`` is an assignment such as ``a = 1``.

    Args:
        expr: The expression to check.

    Returns:
        bool: True for an Eq whose left-hand side consists only of symbols, or
        for a non-empty And of Eq whose first left-hand side consists only of
        symbols (a chained assignment); False otherwise.
    """
    if isinstance(expr, Eq) and is_expr_of_only_symbols(expr.lhs):
        return True

    if not isinstance(expr, And):
        return False

    members = expr._unsorted_args
    if len(members) == 0:
        return False

    # Only the head of the chain has to look like "symbols = ...".
    return all(isinstance(arg, Eq) for arg in members) and is_expr_of_only_symbols(
        members[0].lhs
    )
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
def take_last_relation(expr: And | Relational) -> Relational:
    """Return the last relation of an And chain, descending into nested Ands.

    Anything that is not an And is returned unchanged.
    """
    current = expr
    while isinstance(current, And):
        current = current._unsorted_args[-1]
    return current
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
def take_first_relation(expr: And | Relational) -> Relational:
    """Return the first relation of an And chain, descending into nested Ands.

    Mirrors ``take_last_relation``: if the first argument of an And is itself
    an And, the search continues inside it, so the result is never an And.
    Anything that is not an And is returned unchanged.
    """
    if isinstance(expr, And):
        # Recurse so a nested And is never handed back to callers that go on
        # to read ``.lhs`` from the result.
        return take_first_relation(expr._unsorted_args[0])
    return expr
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
def unwrap_fcs(expr: Basic | MatrixBase) -> Basic | MatrixBase:
    """Replace applications of undefined functions with plain symbols.

    For example, ``Function('f')(x)`` becomes ``Symbol('f_x')``.

    Args:
        expr: The expression to unwrap.

    Returns:
        The expression with every undefined-function call replaced by a symbol
        named ``<function>_<arg>_<arg>...``. Non-Basic inputs, leaves, and
        expressions that cannot be rebuilt are returned unchanged.
    """
    # Nothing to traverse on values that are not sympy Basic objects.
    if not isinstance(expr, Basic):
        return expr

    head = getattr(expr, "func", None)
    if isinstance(head, UndefinedFunction):
        func_name = head.__name__
        # Arguments are unwrapped first so nested calls are flattened as well.
        unwrapped_args = [str(unwrap_fcs(arg)) for arg in expr.args]
        return Symbol(f"{func_name}_{'_'.join(unwrapped_args)}")

    # Any other node: rebuild it from its unwrapped children.
    try:
        children = [unwrap_fcs(arg) for arg in expr.args]
        if children:
            return expr.func(*children)
    except Exception:
        pass

    return expr
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
def sympy_expr_eq(
    gold: Basic | MatrixBase,
    pred: Basic | MatrixBase,
    float_rounding: int,
    numeric_precision: int,
    allow_set_relation_comp: bool = False,
    strict: bool = True,
) -> bool:
    """Compare two sympy expressions for equality using multiple methods.

    Both sides are first normalised (variable renaming when not strict,
    truncation of assignment chains, unwrapping of equations, conversion of
    relations to sets), then compared with the cheapest applicable strategy.
    The function is not symmetric: normalisation follows the format of gold.

    Args:
        gold: First sympy expression (expected)
        pred: Second sympy expression (predicted)
        float_rounding: Number of decimals used when rounding floats; forwarded
            to the numeric/set/relational comparisons.
        numeric_precision: Precision forwarded to the numeric comparisons.
        allow_set_relation_comp: Whether to allow set - relation comparison. Defaults to False.
            - If True, set - relation comparison will be allowed in all cases.
            - If False, set - relation comparison will be allowed only if the prediction is a set.
        strict: If True, variable names matter. If False and both sides have the
            same number of free symbols, the prediction's free symbols are
            substituted with gold's before comparing.

    Returns:
        True if expressions are equal by any comparison method, False otherwise

    Raises:
        ValueError: Propagated from sympy_str_eq when gold (after normalisation)
            is nan or zoo.
    """

    # This ensures that f(x) == f(y) is true
    if not strict:
        try:
            gold_variables = gold.free_symbols
            pred_variables = pred.free_symbols
            if len(gold_variables) == len(pred_variables):
                # NOTE(review): the pairing follows the iteration order of the two
                # free_symbols collections -- presumably unordered sets; confirm.
                pred = pred.subs(
                    list(zip(pred_variables, gold_variables, strict=True))
                )
        except Exception:
            pass

    # If both are assignments, we don't want to unwrap them, so that x=1 != y=1
    # But if one is an assignment and the other is an equation, we want to unwrap both

    # We always want to truncate chains if it's assignment vs assignment

    is_gold_assignment = is_assignment_relation(gold)
    is_pred_assignment = is_assignment_relation(pred)
    is_gold_equation = is_equation(gold)
    is_pred_equation = is_equation(pred)

    # Truncate equation chains in case of assignment (a = b = c becomes a = c); this doesn't
    # change any of the above values, so there is no need to recompute them
    if is_gold_assignment:
        gold = Eq(
            take_first_relation(gold).lhs, take_last_relation(gold).rhs, evaluate=False
        )
    if is_pred_assignment:
        pred = Eq(
            take_first_relation(pred).lhs, take_last_relation(pred).rhs, evaluate=False
        )

    # We follow what the gold format is
    # gold `1` and pred `9=1` -> compare `1` with `1`
    if is_pred_equation and not is_gold_equation:
        # Unwrap pred
        pred = take_last_relation(pred).rhs

    # We respect what the pred format is only if the gold is an assignment, so that
    # gold `x=1` and pred `1` -> compare `1` with `1`, but gold `2x + z = 1` is not unwrapped
    elif is_gold_assignment and not is_pred_equation:
        gold = take_last_relation(gold).rhs

    if is_relation(gold) and isinstance(pred, Set):
        # This is to ensure that 1 < x < 2 equals the interval (1, 2)
        # We also unwrap the functions because otherwise it creates some conditional set based on the function name
        try:
            gold = unwrap_fcs(gold).as_set()
        except Exception:
            pass

    # The reverse direction (pred relation vs gold set) is opt-in.
    if allow_set_relation_comp and is_relation(pred) and isinstance(gold, Set):
        try:
            pred = unwrap_fcs(pred).as_set()
        except Exception:
            pass

    # Start with the cheapest check: direct `==` of the two objects (see sympy_str_eq).
    # Not wrapped in try: the ValueError for a nan/zoo gold propagates to the caller.
    if sympy_str_eq(gold, pred):
        return True

    # Support for equations
    if is_relation(gold) and is_relation(pred):
        return sympy_compare_relational(gold, pred, float_rounding, numeric_precision)

    elif isinstance(gold, (Set, Tuple)) or isinstance(pred, (Set, Tuple)):
        return sympy_compare_sets(gold, pred, float_rounding, numeric_precision)

    # Handles $\text{answer}$ == $answer$: one side is a symbol, the other a multiplication of symbols (a*n*s*w*e*r)
    elif isinstance(gold, Symbol) or isinstance(pred, Symbol):
        return sympy_compare_symbols(gold, pred)

    elif isinstance(gold, (Basic, MatrixBase)) and isinstance(
        pred, (Basic, MatrixBase)
    ):
        # Mostly so that 0.333333 = 1/3
        if sympy_numeric_eq(gold, pred, float_rounding, numeric_precision):
            return True
        # Then try symbolic equality
        if sympy_symbolic_eq(gold, pred):
            return True

    return False
|
| 706 |
+
|
| 707 |
+
|
| 708 |
+
# Verbose-mode regex matching LaTeX fragments that hint at complex numbers
# (the set C, the imaginary unit in several notations, arg/Re/Im) or at matrix
# operations (det, trace, rank). Consumed by should_treat_as_complex.
# NOTE(review): the `\bi\b` alternative matches any standalone "i", which
# presumably also covers "i" used as a plain variable or index -- confirm this
# is acceptable for callers.
complex_number_pattern = re.compile(
    r"""
    # Complex number indicators
    \\mathbb\{C\}| # Complex number set ℂ
    \\i\b| # Complex i
    \bi\b| # Standalone i
    \\text\{i\}| # Text i
    \\mathrm\{i\}| # Roman i
    \\imath\b| # Alternative i notation

    # Matrix operations
    \\det| # Determinant
    \\operatorname\{tr\}| # Trace
    \\operatorname\{rank\}| # Rank
    \\text\{rank\}|
    \\arg\{| # Complex argument
    \\Re\{| # Real part
    \\Im\{| # Imaginary part
    \\operatorname\{Re\}| # Real part alternate
    \\operatorname\{Im\}| # Imaginary part alternate
    \\text\{Re\}| # Real part text
    \\text\{Im\} # Imaginary part text
    """,
    re.VERBOSE,
)
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
def should_treat_as_complex(latex_str: str) -> bool:
    """Tell whether ``latex_str`` contains any marker from ``complex_number_pattern``.

    Args:
        latex_str: LaTeX source text to scan.

    Returns:
        True when the module-level ``complex_number_pattern`` matches anywhere in
        ``latex_str`` (imaginary-unit spellings or operators such as ``\\det``,
        trace, rank, ``Re``/``Im``); False otherwise.
    """
    # `search` scans the whole string; a match object is returned on the first hit.
    first_hit = complex_number_pattern.search(latex_str)
    return first_hit is not None
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
def verify(
    gold: list[Basic | MatrixBase | str] | Basic | MatrixBase | str,
    target: list[Basic | MatrixBase | str] | Basic | MatrixBase | str,
    float_rounding: int = 6,
    numeric_precision: int = 15,
    strict: bool = True,
    allow_set_relation_comp: bool = False,
    timeout_seconds: int | None = 5,
    raise_on_error: bool = False,
) -> bool:
    """Verifies if the target expression matches the gold expression using multiple comparison strategies.

    Every (gold, target) pair from the two inputs is compared; the result is True as
    soon as one pair matches.

    Note:
        - It's expected that both gold and target have been parsed with the math_verify.parse function.
        - Function is not symmetric, the gold answer should be passed as gold and the prediction as target.
          The non-symmetric nature appears at assignment simplification and equation interval conversion.
        - A sympy expression is never compared against a string: such a pair is simply not a match.
          Values that are neither sympy objects nor strings (e.g. a plain Python float) never match either.

    Args:
        gold: The reference/correct expression(s). Can be:
            - A single SymPy expression (Basic or MatrixBase)
            - A string
            - A list of any of the above
        target: The expression(s) to verify. Same types as gold.
        float_rounding: Number of decimal places to round floats to. Defaults to 6.
        numeric_precision: Number of decimal places to consider for numeric comparisons. Defaults to 15.
            - If you know the evaluated expressions will be small, you should increase this.
              See: https://docs.sympy.org/latest/modules/evalf.html
        strict: Whether to enforce strict comparison mode. Defaults to True.
            - In strict mode: Variables matter and sets are not comparable with tuples
            - In non-strict mode: Variables are matched by position and sets can be compared with tuples
        allow_set_relation_comp: Whether to allow set - relation (e.g 1 < x < 2 and (1, 2)) comparison.
            Defaults to False.
            - If True, set - relation comparison will be allowed in all cases.
            - If False, set - relation comparison will be allowed only if the prediction is a set.
        timeout_seconds: Maximum time in seconds to spend on any single comparison operation.
            Defaults to 5 seconds. A value that is not None and > 0 makes the function raise a
            ValueError when it is called in a threaded environment (the timeout relies on signals).
            With None or a value <= 0 a one-time warning is logged instead.
        raise_on_error: Whether to raise an exception if an error occurs during comparison or
            return False. Defaults to False.

    Returns:
        bool: True if target matches gold according to any of the comparison strategies,
        False otherwise.

    Raises:
        ValueError: If the signal-based timeout is used outside the main thread
            (raised regardless of ``raise_on_error``).
        TimeoutException: If a comparison times out and ``raise_on_error`` is True.
        Exception: Any other comparison error, re-raised only when ``raise_on_error`` is True.

    Comparison Strategy:
        1. String to String comparison
        2. Numeric expressions: Comparison within specified precision
        3. Symbolic equality through simplification
        4. Special handling for:
            - Relational expressions (equations/inequalities)
            - Sets and intervals
            - Matrices and vectors
            - Complex numbers
        5. Robust error handling with timeout protection

    Example:
        >>> verify(sympy.Rational(1, 3), sympy.Float(0.333333))  # Numeric comparison
        True
        >>> verify(sympy.Symbol('x') + 1, sympy.Symbol('y') + 1, strict=False)  # Variable matching
        True
        >>> verify(sympy.FiniteSet(1, 2), sympy.Tuple(1, 2), strict=False)  # Set-tuple comparison
        True
    """

    global TIMEOUT_WARNING_SHOWN
    if not TIMEOUT_WARNING_SHOWN and (timeout_seconds is None or timeout_seconds <= 0):
        # Implicit string concatenation keeps source indentation out of the log message.
        logger.warning(
            "Timeout is disabled as timeout_seconds is None or <= 0, you must provide "
            "the logic for timeout interruption yourself to prevent code getting stuck."
        )
        # Module-level flag: the warning is emitted at most once per process.
        TIMEOUT_WARNING_SHOWN = True

    @timeout(timeout_seconds=timeout_seconds)
    def compare_single_extraction(
        gold: Basic | MatrixBase | str, target: Basic | MatrixBase | str
    ) -> bool:
        # If both are sympy expressions, we can use sympy to compare them
        if isinstance(gold, (Basic, MatrixBase)) and isinstance(
            target, (Basic, MatrixBase)
        ):
            return sympy_expr_eq(
                gold, target, float_rounding, numeric_precision, allow_set_relation_comp, strict
            )

        # We don't support str / sympy.Expr comparison. Imo there is no point in doing this, as chances
        # of this happening are very low. The only reason one of them is not converted to a sympy
        # expression is usually that the parsing logic failed; in this case we should improve the
        # parsing logic instead of somehow fixing it ad hoc.
        elif isinstance(gold, str) and isinstance(target, str):
            # We just do string comparison for everything else
            gold = gold.strip()
            target = target.strip()

            # Ensure it's both not empty and equal
            return len(gold) > 0 and len(target) > 0 and gold == target

        return False

    def compare_single_extraction_wrapper(g, t):
        try:
            return compare_single_extraction(g, t)

        # Must come before the generic handlers: if TimeoutException derives from Exception
        # (its definition is not visible here), a later clause would never be reached.
        except TimeoutException as e:
            if raise_on_error:
                raise TimeoutException("Timeout during comparison") from e
            logger.warning("Timeout during comparison")
            return False

        except ValueError as e:
            # The signal-based timeout cannot be installed outside the main thread; surface an
            # actionable message regardless of raise_on_error.
            if str(e) == "signal only works in main thread of the main interpreter":
                raise ValueError(
                    "Math-Verify doesn't support threaded environment due to usage of signal.alarm() in timeout mechanism. If you need to run in multithreaded environment it's recommended to set timeout_seconds=None, which will run without timeout (and signal handling). In this case you need to handle the timeouting yourself."
                ) from e
            if raise_on_error:
                raise
            logger.debug("Error during comparison", exc_info=True)
            return False

        except Exception:
            #! Do not attempt to print out the g and t during handling of exception
            # Because a) it can throw an exception itself and b) it can cause it to be stuck forever during str conversion
            if raise_on_error:
                raise
            logger.debug("Error during comparison", exc_info=True)
            return False

    # Normalize both sides to lists so single values and lists share one code path.
    if not isinstance(gold, list):
        gold = [gold]
    if not isinstance(target, list):
        target = [target]

    # Lazy generator + any(): stops at the first matching (gold, target) pair.
    return any(
        compare_single_extraction_wrapper(g, t) for g, t in product(gold, target)
    )
|