| |
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| import logging |
| import re |
| from itertools import product |
|
|
| from latex2sympy2_extended import is_expr_of_only_symbols |
| from latex2sympy2_extended.logic import And |
| from latex2sympy2_extended.sets import FiniteSet |
| from sympy import ( |
| Basic, |
| E, |
| Eq, |
| Float, |
| GreaterThan, |
| Interval, |
| LessThan, |
| MatrixBase, |
| MatrixExpr, |
| Mul, |
| Number, |
| Integer, |
| Rational, |
| Set, |
| StrictGreaterThan, |
| StrictLessThan, |
| Symbol, |
| Tuple, |
| default_sort_key, |
| nan, |
| ordered, |
| simplify, |
| solve, |
| zoo, |
| UnevaluatedExpr, |
| ) |
| from sympy import FiniteSet as SympyFiniteSet |
| from sympy.core.function import UndefinedFunction |
| from sympy.core.relational import Relational |
|
|
| from math_verify.errors import TimeoutException |
| from math_verify.utils import timeout |
|
|
# Module-level logger used by verify() for comparison diagnostics.
logger = logging.getLogger(__name__)


# Set to True by verify() once the "timeout is disabled" warning has been logged,
# so that the warning is emitted a single time.
TIMEOUT_WARNING_SHOWN = False


# For each supported relation type, the relation type obtained when the two sides
# are swapped (a <= b  <=>  b >= a). Equality is its own mirror image.
INVERSE_RELATIONS = {
    GreaterThan: LessThan,
    LessThan: GreaterThan,
    StrictGreaterThan: StrictLessThan,
    StrictLessThan: StrictGreaterThan,
    Eq: Eq,
}
|
|
|
|
def safe_sympy_doit(a: Basic | MatrixBase):
    """Evaluate ``a`` through its ``doit()`` method without ever raising.

    ``doit()`` walks the expression tree and evaluates its nodes (``1 + 1 + 1``
    becomes ``3``). As noted by the original author, it may evaluate more than
    strictly wanted (integrals are evaluated as well), and expressions produced
    by latex2sympy2_extended are lazy, which is why evaluation is requested
    explicitly here.

    Args:
        a: A sympy Basic or MatrixBase expression to evaluate.

    Returns:
        ``a.doit()`` when the call succeeds, otherwise ``a`` itself, untouched.
    """
    try:
        evaluated = a.doit()
    except Exception:
        # Best effort only: any failure means we keep working with the input.
        return a
    return evaluated
|
|
|
|
def get_pct_val(expr, default=None):
    """Return the value part of a percentage expression, or ``default``.

    A percentage is recognised here as a two-factor ``Mul`` whose second factor
    equals the unevaluated rational 1/100.

    Args:
        expr: The expression to inspect.
        default: Value returned when ``expr`` does not have the percentage shape.

    Returns:
        The first factor of the product, or ``default``.
    """
    looks_like_percentage = (
        isinstance(expr, Mul)
        and len(expr.args) == 2
        and expr.args[1] == UnevaluatedExpr(Rational(1, 100))
    )
    return expr.args[0] if looks_like_percentage else default
|
|
def is_atomic_or_pct_atomic(expr: Basic | MatrixBase, atomic_type: type) -> bool:
    """Tell whether ``expr`` is an ``atomic_type`` or a percentage of one.

    Args:
        expr: The sympy expression to check.
        atomic_type: The atomic type to look for (e.g. ``Number``, ``Float``).

    Returns:
        True if ``expr`` is an instance of ``atomic_type``, or is a two-factor
        ``Mul`` of an ``atomic_type`` value and the unevaluated 1/100 factor;
        False otherwise.
    """
    if isinstance(expr, atomic_type):
        return True

    # Anything that is not a two-factor product cannot be a percentage.
    if not isinstance(expr, Mul) or len(expr.args) != 2:
        return False

    value, scale = expr.args
    return scale == UnevaluatedExpr(Rational(1, 100)) and isinstance(
        value, atomic_type
    )
|
|
|
|
def sympy_numeric_eq(
    a: Basic | MatrixBase,
    b: Basic | MatrixBase,
    float_rounding: int,
    numeric_precision: int,
):
    """Compare two sympy expressions numerically.

    Args:
        a: First sympy expression.
        b: Second sympy expression.
        float_rounding: Number of decimals both sides are rounded to when at
            least one of them is a float (or a percentage of a float).
        numeric_precision: Passed as ``n`` to ``evalf`` when the difference of
            two non-number expressions is evaluated.

    Returns:
        True if the expressions are numerically equal, False otherwise
        (including when the comparison itself fails).
    """
    matrix_like = (MatrixBase, MatrixExpr)

    # Matrix expressions are evaluated first so they can be compared by entry.
    if isinstance(a, matrix_like) and isinstance(b, matrix_like):
        a, b = safe_sympy_doit(a), safe_sympy_doit(b)

    # Concrete matrices of identical shape: every pair of entries must match.
    if isinstance(a, MatrixBase) and isinstance(b, MatrixBase) and a.shape == b.shape:
        for left, right in zip(a.flat(), b.flat(), strict=True):
            if not sympy_numeric_eq(left, right, float_rounding, numeric_precision):
                return False
        return True

    # At least one side is a plain number (or a percentage of one).
    if is_atomic_or_pct_atomic(a, Number) or is_atomic_or_pct_atomic(b, Number):
        if is_atomic_or_pct_atomic(a, Float) or is_atomic_or_pct_atomic(b, Float):
            # A float is involved: compare after rounding both sides.
            a, b = safe_sympy_doit(a), safe_sympy_doit(b)
            try:
                return a.round(float_rounding) == b.round(float_rounding)
            except Exception:
                return False

        if is_atomic_or_pct_atomic(a, Integer) and is_atomic_or_pct_atomic(b, Integer):
            # Integers: the percentage factor is dropped on both sides, so the
            # bare values are compared.
            return get_pct_val(a, a) == get_pct_val(b, b)

        return safe_sympy_doit(a) == safe_sympy_doit(b)

    # General case: the difference must numerically vanish.
    try:
        return (a - b).evalf(chop=True, n=numeric_precision) == 0
    except Exception:
        return False
|
|
|
|
def sympy_symbolic_eq(a: Basic | MatrixBase, b: Basic | MatrixBase) -> bool:
    """Compare two sympy expressions symbolically.

    The difference ``a - b`` is simplified and tested for being zero (or the
    zero matrix).

    Args:
        a: First sympy expression.
        b: Second sympy expression.

    Returns:
        True if the simplified difference is known to be zero, False otherwise
        (including when subtraction or simplification raises).
    """
    try:
        residual = simplify(a - b)
        if isinstance(residual, MatrixBase) and residual.is_zero_matrix:
            return True
        # ``is_zero`` may be None when sympy cannot decide; treat that as False.
        return bool(isinstance(residual, Basic) and residual.is_zero)
    except Exception:
        return False
|
|
|
|
def unwrap_eq(s):
    """Return the right-hand side of an assignment, or ``s`` unchanged.

    For an assignment relation (see ``is_assignment_relation``) the ``rhs`` of
    its last relation is returned; any other input is passed through.
    """
    return take_last_relation(s).rhs if is_assignment_relation(s) else s
|
|
def sort_key(x):
    """Sort key for set elements, based on their unwrapped numeric value.

    The element is unwrapped with ``unwrap_eq`` and, when possible, evaluated
    with ``evalf()`` before being handed to sympy's ``default_sort_key``. If
    that fails, the unevaluated unwrapped element is used instead.
    """
    unwrapped = unwrap_eq(x)
    try:
        return default_sort_key(unwrapped.evalf())
    except Exception:
        return default_sort_key(unwrapped)
|
|
def sympy_deep_compare_set_and_tuple(
    gold: SympyFiniteSet | Tuple,
    pred: SympyFiniteSet | Tuple,
    float_rounding: int,
    numeric_precision: int,
) -> bool:
    """Compare two finite sets / tuples element by element.

    Args:
        gold: Expected finite set or tuple.
        pred: Predicted finite set or tuple.
        float_rounding: Forwarded to ``sympy_expr_eq`` for every element pair.
        numeric_precision: Forwarded to ``sympy_expr_eq`` for every element pair.

    Returns:
        True if both containers have the same length and their elements are
        pairwise equal in the chosen order, False otherwise.

    Note: in order to fully support finite sets, we should ideally do a
    cartesian product comparison, but this is not implemented yet. Elements are
    instead put into a canonical order before being paired up.
    """
    if len(gold) != len(pred):
        return False

    def in_canonical_order(items):
        return list(ordered(items, keys=sort_key, default=False))

    if isinstance(gold, SympyFiniteSet):
        # Gold has no meaningful order, so both sides are sorted.
        gold_args = in_canonical_order(gold.args)
        pred_args = in_canonical_order(pred.args)
    elif isinstance(gold, Tuple) and isinstance(pred, FiniteSet):
        # Gold is ordered and pred is the latex2sympy2_extended FiniteSet:
        # use its ``_unsorted_args`` (presumably the order as written - confirm).
        gold_args = gold.args
        pred_args = pred._unsorted_args
    elif isinstance(pred, SympyFiniteSet):
        # Plain sympy set as prediction: only the prediction is sorted.
        gold_args = gold.args
        pred_args = in_canonical_order(pred.args)
    else:
        gold_args = gold.args
        pred_args = pred.args

    for gold_item, pred_item in zip(gold_args, pred_args, strict=True):
        if not sympy_expr_eq(gold_item, pred_item, float_rounding, numeric_precision):
            return False
    return True
|
|
|
|
def sympy_compare_interval(
    a: Interval, b: Interval, float_rounding: int, numeric_precision: int
) -> bool:
    """Compare two intervals.

    Args:
        a: First interval.
        b: Second interval.
        float_rounding: Forwarded to ``sympy_expr_eq`` for the endpoints.
        numeric_precision: Forwarded to ``sympy_expr_eq`` for the endpoints.

    Returns:
        True if both intervals are open/closed on the same sides and their
        endpoints compare equal, False otherwise.
    """
    same_openness = a.left_open == b.left_open and a.right_open == b.right_open
    if not same_openness:
        return False

    if not sympy_expr_eq(a.start, b.start, float_rounding, numeric_precision):
        return False
    return sympy_expr_eq(a.end, b.end, float_rounding, numeric_precision)
|
|
|
|
def sympy_solve_and_compare(
    gold: Relational, pred: Relational, float_rounding: int, numeric_precision: int
) -> bool:
    """Solve both relations for their free symbols and compare the solutions.

    Args:
        gold: Expected relation.
        pred: Predicted relation.
        float_rounding: Forwarded to ``sympy_expr_eq``.
        numeric_precision: Forwarded to ``sympy_expr_eq``.

    Returns:
        For two equations: True if the ordered solutions pair up with the same
        symbols and equal values. Otherwise the result of comparing the two
        solution lists with ``sympy_expr_eq``.

    Raises:
        Whatever ``solve`` raises, and ``ValueError`` from the strict ``zip``
        when the number of solutions (or of solved symbols) differs. The caller
        is expected to guard this call.
    """
    solved_gold = list(ordered(solve(gold, gold.free_symbols)))
    solved_pred = list(ordered(solve(pred, pred.free_symbols)))

    if not (isinstance(gold, Eq) and isinstance(pred, Eq)):
        return sympy_expr_eq(
            solved_gold, solved_pred, float_rounding, numeric_precision
        )

    # Equations: each solution is treated as a symbol -> value mapping.
    gold_solutions = ordered(solved_gold, keys=sort_key, default=False)
    pred_solutions = ordered(solved_pred, keys=sort_key, default=False)
    for gold_solution, pred_solution in zip(
        gold_solutions, pred_solutions, strict=True
    ):
        paired_items = zip(
            sorted(gold_solution.items()), sorted(pred_solution.items()), strict=True
        )
        for (gold_symbol, gold_value), (pred_symbol, pred_value) in paired_items:
            if not (
                gold_symbol == pred_symbol
                and sympy_expr_eq(
                    gold_value, pred_value, float_rounding, numeric_precision
                )
            ):
                return False
    return True
|
|
|
|
def sympy_compare_relational(
    gold: Relational | And,
    pred: Relational | And,
    float_rounding: int,
    numeric_precision: int,
) -> bool:
    """Compare two relational expressions.

    Strategies, tried in order for two plain relations:
      1. Same relation type and equal ``lhs - rhs``.
      2. Mirrored relation types (per ``INVERSE_RELATIONS``) and
         ``gold.lhs - gold.rhs`` equal to ``pred.rhs - pred.lhs``.
      3. Solving both relations and comparing the solutions.

    Two ``And`` chains are compared pairwise, in written order.

    Args:
        gold: Expected relation (or ``And`` of relations).
        pred: Predicted relation (or ``And`` of relations).
        float_rounding: Forwarded to ``sympy_expr_eq``.
        numeric_precision: Forwarded to ``sympy_expr_eq``.

    Returns:
        True if the relations are found equivalent, False otherwise.
    """
    if isinstance(gold, And) and isinstance(pred, And):
        gold_parts = gold._unsorted_args
        pred_parts = pred._unsorted_args
        # Chains of different length can never match; answer directly instead
        # of letting the strict zip raise ValueError.
        if len(gold_parts) != len(pred_parts):
            return False
        return all(
            sympy_compare_relational(g, p, float_rounding, numeric_precision)
            for g, p in zip(gold_parts, pred_parts, strict=True)
        )

    if not isinstance(gold, Relational) or not isinstance(pred, Relational):
        return False

    # 1. Same kind of relation: a - b must match on both sides.
    try:
        if type(gold) is type(pred) and sympy_expr_eq(
            gold.lhs - gold.rhs, pred.lhs - pred.rhs, float_rounding, numeric_precision
        ):
            return True
    except Exception:
        pass

    # 2. Mirrored relation (a <= b vs b >= a). ``.get`` is used because not every
    #    Relational subclass has an entry in INVERSE_RELATIONS; a missing entry
    #    must fall through to the solver instead of raising KeyError.
    if INVERSE_RELATIONS.get(type(gold)) is type(pred):
        try:
            if sympy_expr_eq(
                gold.lhs - gold.rhs,
                pred.rhs - pred.lhs,
                float_rounding,
                numeric_precision,
            ):
                return True
        except Exception:
            pass

    # 3. Last resort: solve both relations and compare the solutions.
    try:
        if sympy_solve_and_compare(gold, pred, float_rounding, numeric_precision):
            return True
    except Exception:
        pass

    return False
|
|
|
|
def sympy_str_eq(a: Basic | MatrixBase, b: Basic | MatrixBase) -> bool:
    """Compare two sympy expressions with plain ``==``.

    Args:
        a: First sympy expression.
        b: Second sympy expression.

    Returns:
        The result of ``a == b``, or False if that comparison raises.

    Raises:
        ValueError: If ``a`` equals ``nan`` or ``zoo``; such a value is not
            compared at all.
    """
    if any(a == undefined for undefined in (nan, zoo)):
        raise ValueError("Can't evaluate nan or zoo")

    try:
        return a == b
    except Exception:
        return False
|
|
|
|
def sympy_compare_sets(
    gold: Set | Basic | MatrixBase | Tuple,
    pred: Set | Basic | MatrixBase | Tuple,
    float_rounding: int,
    numeric_precision: int,
) -> bool:
    """Compare two sympy sets for equality using multiple methods.

    Anything that is neither a ``Set`` nor a ``Tuple`` is first wrapped in a
    one-element finite set.

    Args:
        gold: First sympy set (expected).
        pred: Second sympy set (predicted).
        float_rounding: Forwarded to the element-wise comparisons.
        numeric_precision: Forwarded to the element-wise comparisons.

    Returns:
        True if the sets are equal by any comparison method, False otherwise.
    """
    collection_types = (Set, Tuple)
    discrete_types = (SympyFiniteSet, Tuple)

    expected = gold if isinstance(gold, collection_types) else SympyFiniteSet(gold)
    actual = pred if isinstance(pred, collection_types) else SympyFiniteSet(pred)

    expected_is_interval = isinstance(expected, Interval)
    actual_is_interval = isinstance(actual, Interval)

    # Two intervals: compare openness and endpoints, nothing else.
    if expected_is_interval and actual_is_interval:
        return sympy_compare_interval(
            expected, actual, float_rounding, numeric_precision
        )

    # Cheap structural equality.
    if expected == actual:
        return True

    # Proper sets with an empty symmetric difference are equal.
    try:
        both_are_sets = isinstance(expected, Set) and isinstance(actual, Set)
        if both_are_sets and expected.symmetric_difference(actual).is_empty:
            return True
    except Exception:
        pass

    expected_is_discrete = isinstance(expected, discrete_types)
    actual_is_discrete = isinstance(actual, discrete_types)

    # Finite sets / tuples: element-wise comparison.
    if expected_is_discrete and actual_is_discrete:
        return sympy_deep_compare_set_and_tuple(
            expected, actual, float_rounding, numeric_precision
        )

    # An open interval against a two-element set/tuple is compared through the
    # interval's endpoints (presumably because "(a, b)" is ambiguous between a
    # pair and an open interval - confirm with the parser).
    if (
        expected_is_interval
        and actual_is_discrete
        and expected.is_open
        and len(actual) == 2
    ):
        return sympy_deep_compare_set_and_tuple(
            Tuple(expected.start, expected.end),
            actual,
            float_rounding,
            numeric_precision,
        )

    if (
        actual_is_interval
        and expected_is_discrete
        and actual.is_open
        and len(expected) == 2
    ):
        return sympy_deep_compare_set_and_tuple(
            expected,
            Tuple(actual.start, actual.end),
            float_rounding,
            numeric_precision,
        )

    return False
|
|
|
|
def sympy_compare_symbols(gold: Basic | MatrixBase, pred: Basic | MatrixBase) -> bool:
    """Compare two sympy expressions where at least one is a Symbol.

    Handles special cases:
    - One is a Symbol named "e"/"E" and the other is Euler's number ``E``.
    - One is a product of symbols (and/or ``E``) and the other a single symbol:
      the factor names are concatenated and compared case-insensitively.
    - Both are symbols: names longer than one character are compared
      case-insensitively, single-character names case-sensitively.

    Args:
        gold: First sympy expression (expected).
        pred: Second sympy expression (predicted).

    Returns:
        True if the expressions are equal by any comparison method, False
        otherwise.
    """

    def is_symbol_e_versus_euler(candidate, other):
        return (
            isinstance(candidate, Symbol)
            and candidate.name.lower() == "e"
            and other == E
        )

    if is_symbol_e_versus_euler(gold, pred) or is_symbol_e_versus_euler(pred, gold):
        return True

    # Single symbol on one side, product of symbols on the other: the product
    # is read as the letters of one multi-character name.
    for single, product_expr in ((gold, pred), (pred, gold)):
        if not (isinstance(single, Symbol) and isinstance(product_expr, Mul)):
            continue
        factors = product_expr.args
        if all(factor == E or isinstance(factor, Symbol) for factor in factors):
            letters = [
                factor.name if isinstance(factor, Symbol) else "e"
                for factor in factors
            ]
            return single.name.lower() == "".join(letters).lower()

    if isinstance(gold, Symbol) and isinstance(pred, Symbol):

        def normalised(name):
            # Only multi-character names are case-folded.
            return name.lower() if len(name) > 1 else name

        return normalised(gold.name) == normalised(pred.name)

    return str(gold) == str(pred)
|
|
|
|
def is_relation(expr: Basic | MatrixBase) -> bool:
    """Check if an expression is a relational expression.

    Args:
        expr: The expression to check.
    Returns:
        bool: True if expr is a ``Relational`` or a non-empty ``And`` whose
        parts are all ``Relational``, False otherwise.
    """
    if isinstance(expr, Relational):
        return True
    if not isinstance(expr, And):
        return False

    parts = expr._unsorted_args
    return len(parts) > 0 and all(isinstance(part, Relational) for part in parts)
|
|
|
|
def is_equation(expr: Basic | MatrixBase) -> bool:
    """Check if an expression is an equation.

    Args:
        expr: The expression to check.
    Returns:
        bool: True if expr is an ``Eq`` or a non-empty ``And`` made only of
        ``Eq`` parts, False otherwise.
    """
    if isinstance(expr, Eq):
        return True
    if not isinstance(expr, And):
        return False

    parts = expr._unsorted_args
    return len(parts) > 0 and all(isinstance(part, Eq) for part in parts)
|
|
|
|
def is_assignment_relation(expr: Basic | MatrixBase) -> bool:
    """Check if an expression is an assignment relation, e.g. ``a = 1``.

    Args:
        expr: The expression to check.
    Returns:
        bool: True if expr is an ``Eq`` whose left-hand side consists only of
        symbols, or a non-empty ``And`` of ``Eq`` parts whose first left-hand
        side consists only of symbols; False otherwise.
    """
    if isinstance(expr, Eq) and is_expr_of_only_symbols(expr.lhs):
        return True
    if not isinstance(expr, And):
        return False

    parts = expr._unsorted_args
    if len(parts) == 0:
        return False
    # Only the very first left-hand side decides whether the chain assigns.
    return all(isinstance(part, Eq) for part in parts) and is_expr_of_only_symbols(
        parts[0].lhs
    )
|
|
|
|
def take_last_relation(expr: And | Relational) -> Relational:
    """Take the last relation from an And expression.

    Nested ``And`` expressions are followed through their last part until a
    non-``And`` value is reached; a non-``And`` input is returned as is.
    """
    current = expr
    while isinstance(current, And):
        current = current._unsorted_args[-1]
    return current
|
|
|
|
def take_first_relation(expr: And | Relational) -> Relational:
    """Take the first relation from an And expression.

    Nested ``And`` expressions are followed through their first part until a
    non-``And`` value is reached, mirroring how ``take_last_relation`` descends
    through the last part. This keeps the declared ``Relational`` return type
    honest: callers read ``.lhs`` on the result, which an ``And`` does not
    necessarily provide. A non-``And`` input is returned as is.
    """
    current = expr
    while isinstance(current, And):
        current = current._unsorted_args[0]
    return current
|
|
|
|
def unwrap_fcs(expr: Basic | MatrixBase) -> Basic | MatrixBase:
    """Unwrap calls of undefined functions into plain symbols.

    For example, ``Function('f')(x)`` becomes ``Symbol('f_x')``.

    Args:
        expr: The expression to unwrap.

    Returns:
        The expression with every undefined-function call replaced by a symbol
        named ``<function>_<arg1>_<arg2>...``. Inputs that are not ``Basic``
        and expressions that cannot be rebuilt are returned unchanged.
    """
    if not isinstance(expr, Basic):
        return expr

    head = getattr(expr, "func", None)
    if isinstance(head, UndefinedFunction):
        # Arguments are unwrapped first so nested calls flatten into one name.
        joined_args = "_".join(str(unwrap_fcs(arg)) for arg in expr.args)
        return Symbol(f"{head.__name__}_{joined_args}")

    # Any other node: rebuild it from its unwrapped children.
    try:
        rebuilt_children = [unwrap_fcs(child) for child in expr.args]
        if rebuilt_children:
            return expr.func(*rebuilt_children)
    except Exception:
        pass

    return expr
|
|
|
|
def sympy_expr_eq(
    gold: Basic | MatrixBase,
    pred: Basic | MatrixBase,
    float_rounding: int,
    numeric_precision: int,
    allow_set_relation_comp: bool = False,
    strict: bool = True,
) -> bool:
    """Compare two sympy expressions for equality using multiple methods.

    Args:
        gold: First sympy expression (expected).
        pred: Second sympy expression (predicted).
        float_rounding: Number of decimals floats are rounded to before being
            compared.
        numeric_precision: Precision used when numerically evaluating the
            difference of two expressions.
        allow_set_relation_comp: Whether to allow set - relation comparison.
            Defaults to False.
            - If True, set - relation comparison will be allowed in all cases.
            - If False, set - relation comparison will be allowed only if the
              prediction is a set.
        strict: If True, variable names matter. If False and both sides have
            the same number of free symbols, the prediction's symbols are
            renamed to the gold ones before comparing.

    Returns:
        True if expressions are equal by any comparison method, False otherwise.

    Raises:
        ValueError: Propagated from ``sympy_str_eq`` when ``gold`` is ``nan``
            or ``zoo``.
    """
    if not strict:
        try:
            # free_symbols is an unordered set: sort both sides so the pairing
            # is deterministic instead of depending on hash order.
            gold_variables = sorted(gold.free_symbols, key=default_sort_key)
            pred_variables = sorted(pred.free_symbols, key=default_sort_key)
            if len(gold_variables) == len(pred_variables):
                # simultaneous=True keeps swapped names apart (x -> y, y -> x);
                # sequential substitution would collapse them into one symbol.
                pred = pred.subs(
                    list(zip(pred_variables, gold_variables, strict=True)),
                    simultaneous=True,
                )
        except Exception:
            pass

    # Classify both sides before any of them is rewritten below.
    is_gold_assignment = is_assignment_relation(gold)
    is_pred_assignment = is_assignment_relation(pred)
    is_gold_equation = is_equation(gold)
    is_pred_equation = is_equation(pred)

    # Collapse assignment chains (a = b = 1) to "first lhs = last rhs".
    if is_gold_assignment:
        gold = Eq(
            take_first_relation(gold).lhs, take_last_relation(gold).rhs, evaluate=False
        )
    if is_pred_assignment:
        pred = Eq(
            take_first_relation(pred).lhs, take_last_relation(pred).rhs, evaluate=False
        )

    if is_pred_equation and not is_gold_equation:
        # Prediction is an equation but gold is a bare value: keep only the
        # prediction's final right-hand side.
        pred = take_last_relation(pred).rhs
    elif is_gold_assignment and not is_pred_equation:
        # Gold is "x = value" while the prediction is a bare value.
        gold = take_last_relation(gold).rhs

    # Gold relation vs. predicted set: turn the relation into a set.
    if is_relation(gold) and isinstance(pred, Set):
        try:
            gold = unwrap_fcs(gold).as_set()
        except Exception:
            pass

    # The opposite direction is only done when explicitly allowed.
    if allow_set_relation_comp and is_relation(pred) and isinstance(gold, Set):
        try:
            pred = unwrap_fcs(pred).as_set()
        except Exception:
            pass

    # Cheapest check first: direct equality.
    if sympy_str_eq(gold, pred):
        return True

    # Dispatch on the kind of objects being compared.
    if is_relation(gold) and is_relation(pred):
        return sympy_compare_relational(gold, pred, float_rounding, numeric_precision)

    if isinstance(gold, (Set, Tuple)) or isinstance(pred, (Set, Tuple)):
        return sympy_compare_sets(gold, pred, float_rounding, numeric_precision)

    if isinstance(gold, Symbol) or isinstance(pred, Symbol):
        return sympy_compare_symbols(gold, pred)

    if isinstance(gold, (Basic, MatrixBase)) and isinstance(pred, (Basic, MatrixBase)):
        # Numeric comparison first, symbolic simplification as the fallback.
        if sympy_numeric_eq(gold, pred, float_rounding, numeric_precision):
            return True
        if sympy_symbolic_eq(gold, pred):
            return True

    return False
|
|
|
|
# Matches LaTeX fragments that hint at complex numbers or matrix operations.
# Compiled with re.VERBOSE, so whitespace and "#" comments inside the pattern
# are ignored by the regex engine.
complex_number_pattern = re.compile(
    r"""
    # Complex number indicators
    \\mathbb\{C\}|        # Complex number set ℂ
    \\i\b|                # Complex i
    \bi\b|                # Standalone i
    \\text\{i\}|          # Text i
    \\mathrm\{i\}|        # Roman i
    \\imath\b|            # Alternative i notation

    # Matrix operations
    \\det|                # Determinant
    \\operatorname\{tr\}| # Trace
    \\operatorname\{rank\}| # Rank
    \\text\{rank\}|
    \\arg\{|              # Complex argument
    \\Re\{|               # Real part
    \\Im\{|               # Imaginary part
    \\operatorname\{Re\}| # Real part alternate
    \\operatorname\{Im\}| # Imaginary part alternate
    \\text\{Re\}|         # Real part text
    \\text\{Im\}          # Imaginary part text
    """,
    re.VERBOSE,
)


def should_treat_as_complex(latex_str: str) -> bool:
    """Tell whether a LaTeX string contains one of the complex/matrix markers.

    Returns True as soon as ``complex_number_pattern`` matches anywhere in
    ``latex_str`` (imaginary unit notations, ``\\mathbb{C}``, determinant,
    trace, rank, real/imaginary part operators), False otherwise.
    """
    first_marker = complex_number_pattern.search(latex_str)
    return first_marker is not None
|
|
|
|
def verify(
    gold: list[Basic | MatrixBase | str] | Basic | MatrixBase | str,
    target: list[Basic | MatrixBase | str] | Basic | MatrixBase | str,
    float_rounding: int = 6,
    numeric_precision: int = 15,
    strict: bool = True,
    allow_set_relation_comp: bool = False,
    timeout_seconds: int | None = 5,
    raise_on_error: bool = False,
) -> bool:
    """Verifies if the target expression matches the gold expression using multiple comparison strategies.

    Every gold candidate is compared against every target candidate; the result
    is True as soon as one pair matches.

    Note:
        - It's expected that both gold and target have been parsed with the
          math_verify.parse function.
        - The function is not symmetric: the gold answer must be passed as
          ``gold`` and the prediction as ``target``. The asymmetry shows up in
          assignment simplification and relation/set conversion.
        - Only two kinds of pairs are compared: two sympy objects
          (Basic/MatrixBase) or two strings. Any other combination (including
          plain Python numbers) yields False.

    Args:
        gold: The reference/correct expression(s). Can be:
            - A single SymPy expression (Basic or MatrixBase)
            - A string
            - A list of any of the above
        target: The expression(s) to verify. Same types as gold.
        float_rounding: Number of decimal places to round floats to. Defaults to 6.
        numeric_precision: Number of decimal places to consider for numeric
            comparisons. Defaults to 15.
            - If you know the evaluated expressions will be small, you should
              increase this. See: https://docs.sympy.org/latest/modules/evalf.html
        strict: Whether variable names matter. Defaults to True.
            - In non-strict mode the free symbols of the target are renamed to
              the gold ones when both sides have the same number of them.
        allow_set_relation_comp: Whether to allow set - relation
            (e.g 1 < x < 2 and (1, 2)) comparison. Defaults to False.
            - If True, set - relation comparison will be allowed in all cases.
            - If False, set - relation comparison will be allowed only if the
              prediction is a set.
        timeout_seconds: Maximum time in seconds to spend on any single
            comparison. Defaults to 5 seconds. ``None`` or a value <= 0 disables
            the timeout (a warning is logged once). With a timeout enabled, the
            function raises a ValueError when called outside the main thread.
        raise_on_error: Whether to raise when an error or timeout occurs during
            a comparison instead of treating that comparison as False.
            Defaults to False.

    Returns:
        bool: True if any target matches any gold according to any of the
        comparison strategies, False otherwise.

    Raises:
        ValueError: If the signal-based timeout is used outside the main thread
            (raised regardless of ``raise_on_error``).
        TimeoutException: If a comparison times out and ``raise_on_error`` is True.
        Exception: Any comparison error, if ``raise_on_error`` is True.

    Example:
        >>> verify(sympy.Integer(2), sympy.Integer(2))
        True
        >>> verify(sympy.Symbol('x') + 1, sympy.Symbol('y') + 1, strict=False)  # Variable matching
        True
        >>> verify(sympy.FiniteSet(1, 2), sympy.Tuple(1, 2), strict=False)  # Set-tuple comparison
        True
    """
    global TIMEOUT_WARNING_SHOWN
    if not TIMEOUT_WARNING_SHOWN and (timeout_seconds is None or timeout_seconds <= 0):
        logger.warning(
            "Timeout is disabled as timeout_seconds is None or <= 0, you must provide "
            "the logic for timeout interuption yourself to prevent code getting stuck."
        )
        TIMEOUT_WARNING_SHOWN = True

    @timeout(timeout_seconds=timeout_seconds)
    def compare_single_extraction(
        gold: Basic | MatrixBase | str, target: Basic | MatrixBase | str
    ) -> bool:
        # Two sympy objects: full expression comparison.
        if isinstance(gold, (Basic, MatrixBase)) and isinstance(
            target, (Basic, MatrixBase)
        ):
            return sympy_expr_eq(
                gold,
                target,
                float_rounding,
                numeric_precision,
                allow_set_relation_comp,
                strict,
            )

        # Two strings: exact match after stripping; empty strings never match.
        if isinstance(gold, str) and isinstance(target, str):
            gold = gold.strip()
            target = target.strip()
            return len(gold) > 0 and len(target) > 0 and gold == target

        # Mixed or unsupported types are never equal.
        return False

    def compare_single_extraction_wrapper(g, t):
        try:
            return compare_single_extraction(g, t)
        except TimeoutException as e:
            # Must be handled before the generic ``Exception`` clause below,
            # otherwise that clause would swallow timeouts first.
            if raise_on_error:
                raise TimeoutException("Timeout during comparison") from e
            logger.warning("Timeout during comparison")
            return False
        except ValueError as e:
            if str(e) == "signal only works in main thread of the main interpreter":
                raise ValueError(
                    "Math-Verify doesn't support threaded environment due to usage of signal.alarm() in timeout mechanism. "
                    "If you need to run in multithreaded environment it's recommended to set the timeout_seconds=None, "
                    "which will run without timeout (and signal handling). "
                    "In this case you need to handle the timeouting yourself."
                ) from e
            if raise_on_error:
                raise
            logger.debug("Error during comparison", exc_info=True)
            return False
        except Exception:
            if raise_on_error:
                raise
            logger.debug("Error during comparison", exc_info=True)
            return False

    if not isinstance(gold, list):
        gold = [gold]
    if not isinstance(target, list):
        target = [target]

    return any(
        compare_single_extraction_wrapper(g, t) for g, t in product(gold, target)
    )
|
|