Varshithdharmajv commited on
Commit
622d468
·
verified ·
1 Parent(s): c21dc0d

Upload math_verify/grader.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. math_verify/grader.py +877 -0
math_verify/grader.py ADDED
@@ -0,0 +1,877 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+
3
+ # Copyright (c) 2024 The HuggingFace Team
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ # Heavily inspired by https://github.com/QwenLM/Qwen2.5-Math and https://github.com/huggingface/lm-evaluation-harness
24
+ import logging
25
+ import re
26
+ from itertools import product
27
+
28
+ from latex2sympy2_extended import is_expr_of_only_symbols
29
+ from latex2sympy2_extended.logic import And
30
+ from latex2sympy2_extended.sets import FiniteSet
31
+ from sympy import (
32
+ Basic,
33
+ E,
34
+ Eq,
35
+ Float,
36
+ GreaterThan,
37
+ Interval,
38
+ LessThan,
39
+ MatrixBase,
40
+ MatrixExpr,
41
+ Mul,
42
+ Number,
43
+ Rational,
44
+ Set,
45
+ StrictGreaterThan,
46
+ StrictLessThan,
47
+ Symbol,
48
+ Tuple,
49
+ default_sort_key,
50
+ nan,
51
+ ordered,
52
+ simplify,
53
+ solve,
54
+ zoo,
55
+ )
56
+ from sympy import FiniteSet as SympyFiniteSet
57
+ from sympy.core.function import UndefinedFunction
58
+ from sympy.core.relational import Relational
59
+
60
+ from math_verify.errors import TimeoutException
61
+ from math_verify.utils import timeout
62
+
63
+ logger = logging.getLogger(__name__)
64
+
65
+ TIMEOUT_WARNING_SHOWN = False
66
+
67
+
68
+ INVERSE_RELATIONS = {
69
+ GreaterThan: LessThan,
70
+ LessThan: GreaterThan,
71
+ StrictGreaterThan: StrictLessThan,
72
+ StrictLessThan: StrictGreaterThan,
73
+ Eq: Eq,
74
+ }
75
+
76
+
77
+ def safe_sympy_doit(a: Basic | MatrixBase):
78
+ """Safely execute doit() on a sympy expression, catching exceptions.
79
+ Doit in sympy will evaluate expressions it will pass the expression tree and evluate nodes.
80
+ For example for 1+1+1 it will evaluate the additions and return 3. One issue with it is that it maybe
81
+ evaluates too much as integrals will also be evaluated.
82
+ As we are using latex2sympy2_extended, evaluates are lazy and only evaluated when needed.
83
+
84
+ Args:
85
+ a: A sympy Basic or MatrixBase expression to evaluate
86
+
87
+ Returns:
88
+ The result of a.doit() if successful, otherwise returns the original expression
89
+ """
90
+ try:
91
+ return a.doit()
92
+ except Exception:
93
+ pass
94
+ return a
95
+
96
+
97
+ def is_atomic_or_pct_atomic(expr: Basic | MatrixBase, atomic_type: type) -> bool:
98
+ """Check if expression is either an atomic type or percentage atomic type.
99
+
100
+ Args:
101
+ expr: The sympy expression to check
102
+ atomic_type: The atomic type to check for
103
+
104
+ Returns:
105
+ True if expr is atomic_type or percentage atomic type, False otherwise
106
+ """
107
+ return isinstance(expr, atomic_type) or (
108
+ # Check for percentage representation: latex2sympy_extended converts "X%" into X*Rational(1,100)
109
+ # So we detect percentages by looking for this multiplication structure
110
+ isinstance(expr, Mul)
111
+ and len(expr.args) == 2
112
+ and expr.args[1] == Rational(1, 100)
113
+ and isinstance(expr.args[0], atomic_type)
114
+ )
115
+
116
+
117
+ def sympy_numeric_eq(
118
+ a: Basic | MatrixBase,
119
+ b: Basic | MatrixBase,
120
+ float_rounding: int,
121
+ numeric_precision: int,
122
+ ):
123
+ """Compare two sympy expressions numerically with given precision.
124
+
125
+ Args:
126
+ a: First sympy expression
127
+ b: Second sympy expression
128
+ precision: Number of decimal places to compare
129
+
130
+ Returns:
131
+ True if expressions are numerically equal within precision, False otherwise
132
+ """
133
+ # Only do this when one of the two is a float, in other cases use symbolic equality as this could lead to false positives
134
+ # E.g we want 1/3 == 0.333333 to work
135
+ if isinstance(a, (MatrixBase, MatrixExpr)) and isinstance(
136
+ b, (MatrixBase, MatrixExpr)
137
+ ):
138
+ a = safe_sympy_doit(a)
139
+ b = safe_sympy_doit(b)
140
+
141
+ # If we have matrices and one of them is only made of floats, we can use the same logic as above
142
+ if (
143
+ isinstance(a, (MatrixBase))
144
+ and isinstance(b, (MatrixBase))
145
+ and a.shape == b.shape
146
+ ):
147
+ return all(
148
+ sympy_numeric_eq(a_elem, b_elem, float_rounding, numeric_precision)
149
+ for a_elem, b_elem in zip(a.flat(), b.flat(), strict=True)
150
+ )
151
+
152
+ # Ensure this also works for percentage numbers so that 0.333333% = 0.33333333333 with precision 4
153
+ elif is_atomic_or_pct_atomic(a, Number) or is_atomic_or_pct_atomic(b, Number):
154
+ # If one of them is a float or a percentage number, we can try to use float precision
155
+ if is_atomic_or_pct_atomic(a, Float) or is_atomic_or_pct_atomic(b, Float):
156
+ a = safe_sympy_doit(a)
157
+ b = safe_sympy_doit(b)
158
+ try:
159
+ return a.round(float_rounding) == b.round(float_rounding)
160
+ except Exception:
161
+ pass
162
+ else:
163
+ return safe_sympy_doit(a) == safe_sympy_doit(b)
164
+
165
+ else:
166
+ try:
167
+ return (a - b).evalf(chop=True, n=numeric_precision) == 0 # type: ignore
168
+ except Exception:
169
+ pass
170
+
171
+ return False
172
+
173
+
174
+ def sympy_symbolic_eq(a: Basic | MatrixBase, b: Basic | MatrixBase) -> bool:
175
+ """Compare two sympy expressions symbolically.
176
+
177
+ Args:
178
+ a: First sympy expression
179
+ b: Second sympy expression
180
+
181
+ Returns:
182
+ True if expressions are symbolically equal, False otherwise
183
+ """
184
+ try:
185
+ a_b_diff = simplify((a - b)) # type: ignore
186
+ if isinstance(a_b_diff, MatrixBase) and a_b_diff.is_zero_matrix:
187
+ return True
188
+ elif isinstance(a_b_diff, Basic) and a_b_diff.is_zero:
189
+ return True
190
+ except Exception:
191
+ pass
192
+
193
+ return False
194
+
195
+
196
+ def unwrap_eq(s):
197
+ if is_assignment_relation(s):
198
+ return take_last_relation(s).rhs
199
+ return s
200
+
201
+ def sort_key(x):
202
+ try:
203
+ return default_sort_key(unwrap_eq(x).evalf())
204
+ except Exception:
205
+ return default_sort_key(unwrap_eq(x))
206
+
207
+ def sympy_deep_compare_set_and_tuple(
208
+ gold: SympyFiniteSet | Tuple,
209
+ pred: SympyFiniteSet | Tuple,
210
+ float_rounding: int,
211
+ numeric_precision: int,
212
+ ) -> bool:
213
+ """Compare two finite sets by comparing each element with given precision.
214
+
215
+ Args:
216
+ a: First finite set
217
+ b: Second finite set
218
+ precision: Number of decimal places to compare
219
+
220
+ Returns:
221
+ True if sets contain equal elements within precision, False otherwise
222
+
223
+ Note: in order to fully support finite sets, we should ideally do kartesian product comparison
224
+ but this is not implemented yet. We kinda hope sympy will order the elements.
225
+ """
226
+
227
+ # This ensures it works for {1/3} and {0.333333}
228
+ if len(gold) == len(pred):
229
+ if isinstance(gold, SympyFiniteSet):
230
+ gold_args = list(ordered(gold.args, keys=sort_key, default=False))
231
+ pred_args = list(ordered(pred.args, keys=sort_key, default=False))
232
+
233
+ elif isinstance(gold, Tuple) and isinstance(pred, FiniteSet):
234
+ # We treat the pred as tuple too
235
+ pred_args = pred._unsorted_args
236
+ gold_args = gold.args
237
+
238
+ elif isinstance(pred, SympyFiniteSet):
239
+ pred_args = list(ordered(pred.args, keys=sort_key, default=False))
240
+ gold_args = gold.args
241
+ else:
242
+ gold_args = gold.args
243
+ pred_args = pred.args
244
+
245
+ return all(
246
+ sympy_expr_eq(a, b, float_rounding, numeric_precision)
247
+ for a, b in zip(gold_args, pred_args, strict=True)
248
+ )
249
+
250
+ return False
251
+
252
+
253
+ def sympy_compare_interval(
254
+ a: Interval, b: Interval, float_rounding: int, numeric_precision: int
255
+ ) -> bool:
256
+ """Compare two intervals.
257
+
258
+ Args:
259
+ a: First interval
260
+ b: Second interval
261
+ precision: Number of decimal places to compare endpoints
262
+
263
+ Returns:
264
+ True if intervals are equal, False otherwise
265
+ """
266
+ return (
267
+ a.left_open == b.left_open
268
+ and a.right_open == b.right_open
269
+ and sympy_expr_eq(a.start, b.start, float_rounding, numeric_precision)
270
+ and sympy_expr_eq(a.end, b.end, float_rounding, numeric_precision)
271
+ )
272
+
273
+
274
+ def sympy_solve_and_compare(
275
+ gold: Relational, pred: Relational, float_rounding: int, numeric_precision: int
276
+ ) -> bool:
277
+ solved_gold = list(ordered(solve(gold, gold.free_symbols)))
278
+ solved_pred = list(ordered(solve(pred, pred.free_symbols)))
279
+ # Equalities should return list of dicts of solutions
280
+ if isinstance(gold, Eq) and isinstance(pred, Eq):
281
+ return all(
282
+ all(
283
+ g_k == p_k
284
+ and sympy_expr_eq(g_v, p_v, float_rounding, numeric_precision)
285
+ for (g_k, g_v), (p_k, p_v) in zip(
286
+ sorted(g.items()), sorted(p.items()), strict=True
287
+ )
288
+ )
289
+ for g, p in zip(ordered(solved_gold, keys=sort_key, default=False), ordered(solved_pred, keys=sort_key, default=False), strict=True)
290
+ )
291
+ else:
292
+ return sympy_expr_eq(
293
+ solved_gold, solved_pred, float_rounding, numeric_precision
294
+ )
295
+
296
+
297
+ def sympy_compare_relational(
298
+ gold: Relational | And,
299
+ pred: Relational | And,
300
+ float_rounding: int,
301
+ numeric_precision: int,
302
+ ) -> bool:
303
+ """Compare two relational expressions.
304
+
305
+ Args:
306
+ gold: First relational expression
307
+ pred: Second relational expression
308
+ precision: Number of decimal places to compare
309
+
310
+ Returns:
311
+ True if relations are equivalent, False otherwise
312
+ """
313
+
314
+ if isinstance(gold, And) and isinstance(pred, And):
315
+ return all(
316
+ sympy_compare_relational(g, p, float_rounding, numeric_precision)
317
+ for g, p in zip(gold._unsorted_args, pred._unsorted_args, strict=True)
318
+ )
319
+
320
+ elif not isinstance(gold, Relational) or not isinstance(pred, Relational):
321
+ return False
322
+
323
+ # Helper to check if expressions are equivalent when flipped
324
+ def are_flipped_inequalities_equal(a: Relational, b: Relational) -> bool:
325
+ try:
326
+ return sympy_expr_eq(
327
+ a.lhs - a.rhs, b.rhs - b.lhs, float_rounding, numeric_precision
328
+ ) # type: ignore
329
+ except Exception:
330
+ pass
331
+ return False
332
+
333
+ # Same type of relation (e.g. both <= or both >=)
334
+ try:
335
+ if type(gold) is type(pred) and sympy_expr_eq(
336
+ gold.lhs - gold.rhs, pred.lhs - pred.rhs, float_rounding, numeric_precision
337
+ ): # type: ignore
338
+ return True
339
+ except Exception:
340
+ pass
341
+
342
+ # Check flipped inequalities (a <= b equals b >= a)
343
+ if INVERSE_RELATIONS[type(gold)] is type(pred) and are_flipped_inequalities_equal( # type: ignore
344
+ gold, pred
345
+ ):
346
+ return True
347
+
348
+ try:
349
+ if sympy_solve_and_compare(gold, pred, float_rounding, numeric_precision):
350
+ return True
351
+ except Exception:
352
+ pass
353
+
354
+ return False
355
+
356
+
357
+ def sympy_str_eq(a: Basic | MatrixBase, b: Basic | MatrixBase) -> bool:
358
+ """Compare two sympy expressions by string representation.
359
+
360
+ Args:
361
+ a: First sympy expression
362
+ b: Second sympy expression
363
+
364
+ Returns:
365
+ True if string representations are equal, False otherwise
366
+ """
367
+ # We can't evaluate nan or zoo
368
+ if a == nan or a == zoo:
369
+ raise ValueError("Can't evaluate nan or zoo")
370
+ try:
371
+ return a == b
372
+ except Exception:
373
+ pass
374
+ return False
375
+
376
+
377
+ def sympy_compare_sets(
378
+ gold: Set | Basic | MatrixBase | Tuple,
379
+ pred: Set | Basic | MatrixBase | Tuple,
380
+ float_rounding: int,
381
+ numeric_precision: int,
382
+ ) -> bool:
383
+ """Compare two sympy sets for equality using multiple methods.
384
+
385
+ Args:
386
+ gold: First sympy set (expected)
387
+ pred: Second sympy set (predicted)
388
+ precision: Number of decimal places to compare
389
+
390
+ Returns:
391
+ True if sets are equal by any comparison method, False otherwise
392
+ """
393
+ # Convert non-sets to singleton sets
394
+ a_set = gold if isinstance(gold, (Set, Tuple)) else SympyFiniteSet(gold)
395
+ b_set = pred if isinstance(pred, (Set, Tuple)) else SympyFiniteSet(pred)
396
+
397
+ # If both are intervals, use interval comparison
398
+ if isinstance(a_set, Interval) and isinstance(b_set, Interval):
399
+ return sympy_compare_interval(a_set, b_set, float_rounding, numeric_precision)
400
+
401
+ # Try direct set equality
402
+ if a_set == b_set:
403
+ return True
404
+
405
+ # If both are sets, check if they are equal
406
+ try:
407
+ if (
408
+ isinstance(a_set, Set)
409
+ and isinstance(b_set, Set)
410
+ and a_set.symmetric_difference(b_set).is_empty
411
+ ):
412
+ return True
413
+ except Exception:
414
+ pass
415
+
416
+ # For finite sets, compare elements
417
+ if isinstance(a_set, (SympyFiniteSet, Tuple)) and isinstance(
418
+ b_set, (SympyFiniteSet, Tuple)
419
+ ):
420
+ return sympy_deep_compare_set_and_tuple(
421
+ a_set, b_set, float_rounding, numeric_precision
422
+ )
423
+
424
+ # Because (1,2) is parsed as Interval(1,2,left_open=True,right_open=True), it could have that the
425
+ # correct is (1,2) and predicted is 1,2, which is parsed as Set(1,2)
426
+ if isinstance(a_set, Interval) and isinstance(b_set, (SympyFiniteSet, Tuple)):
427
+ if a_set.is_open and len(b_set) == 2:
428
+ return sympy_deep_compare_set_and_tuple(
429
+ Tuple(a_set.start, a_set.end), b_set, float_rounding, numeric_precision
430
+ )
431
+
432
+ if isinstance(b_set, Interval) and isinstance(a_set, (SympyFiniteSet, Tuple)):
433
+ if b_set.is_open and len(a_set) == 2:
434
+ return sympy_deep_compare_set_and_tuple(
435
+ a_set, Tuple(b_set.start, b_set.end), float_rounding, numeric_precision
436
+ )
437
+
438
+ return False
439
+
440
+
441
+ def sympy_compare_symbols(gold: Basic | MatrixBase, pred: Basic | MatrixBase) -> bool:
442
+ """Compare two sympy expressions where at least one is a Symbol.
443
+
444
+ Handles special cases:
445
+ - One is Symbol and other is E (limitation of parsed expressions)
446
+ - One is multiplication of symbols and other is single symbol (concatenated comparison)
447
+
448
+ Args:
449
+ gold: First sympy expression (expected)
450
+ pred: Second sympy expression (predicted)
451
+ precision: Number of decimal places to compare
452
+
453
+ Returns:
454
+ True if expressions are equal by any comparison method, False otherwise
455
+ """
456
+ # Handle E vs symbol case
457
+ if (isinstance(gold, Symbol) and gold.name.lower() == "e" and pred == E) or (
458
+ isinstance(pred, Symbol) and pred.name.lower() == "e" and gold == E
459
+ ):
460
+ return True
461
+
462
+ # Handle multiplication of symbols vs single symbol, because parsing return $abc$ -> abc
463
+ # We also handle E as it's a symbol, because E will be always parsed as exp
464
+ if (
465
+ isinstance(gold, Symbol)
466
+ and isinstance(pred, Mul)
467
+ and all(arg == E or isinstance(arg, (Symbol)) for arg in pred.args)
468
+ ):
469
+ concat_pred = "".join(
470
+ arg.name if isinstance(arg, Symbol) else "e" for arg in pred.args
471
+ )
472
+ return gold.name.lower() == concat_pred.lower()
473
+
474
+ if (
475
+ isinstance(pred, Symbol)
476
+ and isinstance(gold, Mul)
477
+ and all(arg == E or isinstance(arg, (Symbol)) for arg in gold.args)
478
+ ):
479
+ concat_gold = "".join(
480
+ arg.name if isinstance(arg, Symbol) else "e" for arg in gold.args
481
+ )
482
+ return pred.name.lower() == concat_gold.lower()
483
+
484
+ # Simple
485
+ if isinstance(gold, Symbol) and isinstance(pred, Symbol):
486
+ g_name = gold.name
487
+ p_name = pred.name
488
+ if len(p_name) > 1:
489
+ p_name = p_name.lower()
490
+ if len(g_name) > 1:
491
+ g_name = g_name.lower()
492
+ return g_name == p_name
493
+
494
+ return str(gold) == str(pred)
495
+
496
+
497
+ def is_relation(expr: Basic | MatrixBase) -> bool:
498
+ """Check if an expression is a relational expression.
499
+
500
+ Args:
501
+ expr: The expression to check
502
+ Returns:
503
+ bool: True if expr is a relational expression or And of relations, False otherwise
504
+ """
505
+ if isinstance(expr, Relational):
506
+ return True
507
+
508
+ if isinstance(expr, And) and len(expr._unsorted_args) > 0:
509
+ return all(isinstance(arg, Relational) for arg in expr._unsorted_args)
510
+
511
+ return False
512
+
513
+
514
+ def is_equation(expr: Basic | MatrixBase) -> bool:
515
+ """Check if an expression is an equation.
516
+
517
+ Args:
518
+ expr: The expression to check
519
+ Returns:
520
+ bool: True if expr is an equation, False otherwise
521
+ """
522
+ if isinstance(expr, Eq):
523
+ return True
524
+
525
+ if isinstance(expr, And) and len(expr._unsorted_args) > 0:
526
+ return all(isinstance(arg, Eq) for arg in expr._unsorted_args)
527
+
528
+ return False
529
+
530
+
531
+ def is_assignment_relation(expr: Basic | MatrixBase) -> bool:
532
+ """Check if an expression is an assignment relation. E.g a=1
533
+
534
+ Args:
535
+ expr: The expression to check
536
+ Returns:
537
+ bool: True if expr is a relational expression or And of relations, False otherwise
538
+ """
539
+ if isinstance(expr, Eq) and is_expr_of_only_symbols(expr.lhs):
540
+ return True
541
+
542
+ if isinstance(expr, And) and len(expr._unsorted_args) > 0:
543
+ return all(
544
+ isinstance(arg, Eq) for arg in expr._unsorted_args
545
+ ) and is_expr_of_only_symbols(expr._unsorted_args[0].lhs)
546
+
547
+ return False
548
+
549
+
550
+ def take_last_relation(expr: And | Relational) -> Relational:
551
+ """Take the last relation from an And expression."""
552
+ if isinstance(expr, And):
553
+ return take_last_relation(expr._unsorted_args[-1])
554
+ return expr
555
+
556
+
557
+ def take_first_relation(expr: And | Relational) -> Relational:
558
+ """Take the first relation from an And expression."""
559
+ if isinstance(expr, And):
560
+ return expr._unsorted_args[0]
561
+ return expr
562
+
563
+
564
+ def unwrap_fcs(expr: Basic | MatrixBase) -> Basic | MatrixBase:
565
+ """Unwrap function calls to their arguments.
566
+
567
+ For example, Function('f')(x) becomes Symbol('f_x')
568
+
569
+ Args:
570
+ expr: The expression to unwrap
571
+
572
+ Returns:
573
+ The unwrapped expression with functions replaced by concatenated symbols
574
+ """
575
+ # Base case - not a Basic type
576
+ if not isinstance(expr, Basic):
577
+ return expr
578
+
579
+ # Handle function case
580
+ if hasattr(expr, "func") and isinstance(expr.func, UndefinedFunction):
581
+ # Get function name and arguments
582
+ func_name = expr.func.__name__
583
+ # Recursively unwrap arguments before converting to string
584
+ unwrapped_args = [str(unwrap_fcs(arg)) for arg in expr.args]
585
+ # Create new symbol by concatenating function name and args
586
+ return Symbol(f"{func_name}_{'_'.join(unwrapped_args)}")
587
+
588
+ # Recursively unwrap all arguments
589
+ try:
590
+ new_args = [unwrap_fcs(arg) for arg in expr.args]
591
+ if new_args:
592
+ return expr.func(*new_args)
593
+ except Exception:
594
+ pass
595
+
596
+ return expr
597
+
598
+
599
+ def sympy_expr_eq(
600
+ gold: Basic | MatrixBase,
601
+ pred: Basic | MatrixBase,
602
+ float_rounding: int,
603
+ numeric_precision: int,
604
+ allow_set_relation_comp: bool = False,
605
+ strict: bool = True,
606
+ ) -> bool:
607
+ """Compare two sympy expressions for equality using multiple methods.
608
+
609
+ Args:
610
+ gold: First sympy expression (expected)
611
+ pred: Second sympy expression (predicted)
612
+ precision: Number of decimal places to compare
613
+ allow_set_relation_comp: Whether to allow set - relation comparison. Defaults to False.
614
+ - If True, set - relation comparison will be allowed in all cases.
615
+ - If False, set - relation comparison will be allowed only if the prediction is a set.
616
+ strict: If true, variables do matter otherwise they don't
617
+
618
+ Returns:
619
+ True if expressions are equal by any comparison method, False otherwise
620
+ """
621
+
622
+ # This ensures that f(x) == f(y) is true
623
+ if not strict:
624
+ try:
625
+ gold_variables = gold.free_symbols
626
+ pred_variables = pred.free_symbols
627
+ if len(gold_variables) == len(pred_variables):
628
+ pred = pred.subs(
629
+ list(zip(pred_variables, gold_variables, strict=True))
630
+ )
631
+ except Exception:
632
+ pass
633
+
634
+ # If both are assigments, we don't want to unwrap them, so that x=1 != y=1
635
+ # But if one is assignment and other is equation, we want to unwrap both
636
+
637
+ # We always want to truncate if it's assignment, assignment
638
+
639
+ is_gold_assignment = is_assignment_relation(gold)
640
+ is_pred_assignment = is_assignment_relation(pred)
641
+ is_gold_equation = is_equation(gold)
642
+ is_pred_equation = is_equation(pred)
643
+
644
+ # Truncate equations chains in case of assignment, this doesn't change any of the above values,
645
+ # so no need to recompute them
646
+ if is_gold_assignment:
647
+ gold = Eq(
648
+ take_first_relation(gold).lhs, take_last_relation(gold).rhs, evaluate=False
649
+ )
650
+ if is_pred_assignment:
651
+ pred = Eq(
652
+ take_first_relation(pred).lhs, take_last_relation(pred).rhs, evaluate=False
653
+ )
654
+
655
+ # We follow what the gold format is
656
+ # 1 and 9=1 -> 1,1
657
+ if is_pred_equation and not is_gold_equation:
658
+ # Unwrap pred
659
+ pred = take_last_relation(pred).rhs
660
+
661
+ # We respect what the pred format is only if the gold is assignment so that x=1 and 1 -> 1,1, but not 2x + z = 1 and 1 -> 1,1
662
+ elif is_gold_assignment and not is_pred_equation:
663
+ gold = take_last_relation(gold).rhs
664
+
665
+ if is_relation(gold) and isinstance(pred, Set):
666
+ # This is to ensure that 1 < x < 2 equals (-oo, 1) U (2, oo)
667
+ # We also unwrap the functions because othewise it creates some conditional set based on the function name
668
+ try:
669
+ gold = unwrap_fcs(gold).as_set()
670
+ except Exception:
671
+ pass
672
+
673
+ if allow_set_relation_comp and is_relation(pred) and isinstance(gold, Set):
674
+ try:
675
+ pred = unwrap_fcs(pred).as_set()
676
+ except Exception:
677
+ pass
678
+
679
+ # Start with simple str and expr comparisson as it's the fastest
680
+ # str comparison is better, than simple eq, because it will also handle missarangments
681
+ if sympy_str_eq(gold, pred):
682
+ return True
683
+
684
+ # Support for equations
685
+ if is_relation(gold) and is_relation(pred):
686
+ return sympy_compare_relational(gold, pred, float_rounding, numeric_precision)
687
+
688
+ elif isinstance(gold, (Set, Tuple)) or isinstance(pred, (Set, Tuple)):
689
+ return sympy_compare_sets(gold, pred, float_rounding, numeric_precision)
690
+
691
+ # Handles $\text{answer}$ == $answer$, one is symbol, is multiplication of symbols (a*n*s*w*e*r)
692
+ elif isinstance(gold, Symbol) or isinstance(pred, Symbol):
693
+ return sympy_compare_symbols(gold, pred)
694
+
695
+ elif isinstance(gold, (Basic, MatrixBase)) and isinstance(
696
+ pred, (Basic, MatrixBase)
697
+ ):
698
+ # Mostly so that 0.333333 = 1/3
699
+ if sympy_numeric_eq(gold, pred, float_rounding, numeric_precision):
700
+ return True
701
+ # Then try symbolic equality
702
+ if sympy_symbolic_eq(gold, pred):
703
+ return True
704
+
705
+ return False
706
+
707
+
708
+ complex_number_pattern = re.compile(
709
+ r"""
710
+ # Complex number indicators
711
+ \\mathbb\{C\}| # Complex number set ℂ
712
+ \\i\b| # Complex i
713
+ \bi\b| # Standalone i
714
+ \\text\{i\}| # Text i
715
+ \\mathrm\{i\}| # Roman i
716
+ \\imath\b| # Alternative i notation
717
+
718
+ # Matrix operations
719
+ \\det| # Determinant
720
+ \\operatorname\{tr\}| # Trace
721
+ \\operatorname\{rank\}| # Rank
722
+ \\text\{rank\}|
723
+ \\arg\{| # Complex argument
724
+ \\Re\{| # Real part
725
+ \\Im\{| # Imaginary part
726
+ \\operatorname\{Re\}| # Real part alternate
727
+ \\operatorname\{Im\}| # Imaginary part alternate
728
+ \\text\{Re\}| # Real part text
729
+ \\text\{Im\} # Imaginary part text
730
+ """,
731
+ re.VERBOSE,
732
+ )
733
+
734
+
735
+ def should_treat_as_complex(latex_str: str) -> bool:
736
+ """
737
+ Returns True if the latex string likely contains complex numbers, matrices, or vectors.
738
+ """
739
+
740
+ return bool(complex_number_pattern.search(latex_str))
741
+
742
+
743
+ def verify(
744
+ gold: list[Basic | MatrixBase | str] | Basic | MatrixBase | str,
745
+ target: list[Basic | MatrixBase | str] | Basic | MatrixBase | str,
746
+ float_rounding: int = 6,
747
+ numeric_precision: int = 15,
748
+ strict: bool = True,
749
+ allow_set_relation_comp: bool = False,
750
+ timeout_seconds: int | None = 5,
751
+ raise_on_error: bool = False,
752
+ ) -> bool:
753
+ """Verifies if the target expression matches the gold expression using multiple comparison strategies.
754
+
755
+ This function implements a comprehensive comparison system for mathematical expressions,
756
+ handling various types of mathematical objects (numbers, expressions, sets, matrices, etc.)
757
+ with multiple fallback strategies.
758
+
759
+ Note:
760
+ - It's expected that both gold and pred has been parsed with math_verify.parse function.
761
+ - Function is not symmetric, gold answer should be passed as gold and prediction as pred. The non-symmetric nature appears at assignment simplification and equation interval conversion.
762
+
763
+ Args:
764
+ gold: The reference/correct expression(s). Can be:
765
+ - A single SymPy expression (Basic or MatrixBase)
766
+ - A string
767
+ - A list of any of the above
768
+ target: The expression(s) to verify. Same types as gold.
769
+ float_rounding: Number of decimal places to round floats to. Defaults to 6.
770
+ numeric_precision: Number of decimal places to consider for numeric comparisons. Defaults to 15.
771
+ - If you know the evaluated expressions will be small, you should increase this. See: https://docs.sympy.org/latest/modules/evalf.html
772
+ strict: Whether to enforce strict comparison mode. Defaults to True.
773
+ - In strict mode: Variables matter and sets are not comparable with tuples
774
+ - In non-strict mode: Variables are matched by position and sets can be compared with tuples
775
+ timeout_seconds: Maximum time in seconds to spend on any single comparison operation.
776
+ Defaults to 5 seconds. Any timeout seconds > 0 or not None will result in the function to raise a ValueError if it's called in a threaded environment.
777
+ allow_set_relation_comp: Whether to allow set - relation (e.g 1 < x < 2 and (1, 2)) comparison. Defaults to False.
778
+ - If True, set - relation comparison will be allowed in all cases.
779
+ - If False, set - relation comparison will be allowed only if the prediction is a set.
780
+ raise_on_error: Whether to raise an exception if an error occurs during comparison or return False. Defaults to False.
781
+
782
+ Returns:
783
+ bool: True if target matches gold according to any of the comparison strategies,
784
+ False otherwise.
785
+
786
+ Comparison Strategy:
787
+ 1. String to String comparison
788
+ 2. Numeric expressions: Comparison within specified precision
789
+ 3. Symbolic equality through simplification
790
+ 4. Special handling for:
791
+ - Relational expressions (equations/inequalities)
792
+ - Sets and intervals
793
+ - Matrices and vectors
794
+ - Complex numbers
795
+ 5. Robust error handling with timeout protection
796
+
797
+ Example:
798
+ >>> verify(sympy.Rational(1, 3), 0.333333) # Numeric comparison
799
+ True
800
+ >>> verify(sympy.Symbol('x') + 1, sympy.Symbol('y') + 1, strict=False) # Variable matching
801
+ True
802
+ >>> verify(sympy.FiniteSet(1, 2), sympy.Tuple(1, 2), strict=False) # Set-tuple comparison
803
+ True
804
+ """
805
+
806
+ global TIMEOUT_WARNING_SHOWN
807
+ if not TIMEOUT_WARNING_SHOWN and (timeout_seconds is None or timeout_seconds <= 0):
808
+ logger.warning(
809
+ "Timeout is disabled as timeout_seconds is None or <= 0, you must provide \
810
+ the logic for timeout interuption yourself to prevent code getting stuck."
811
+ )
812
+ TIMEOUT_WARNING_SHOWN = True
813
+
814
+ @timeout(timeout_seconds=timeout_seconds)
815
+ def compare_single_extraction(
816
+ gold: Basic | MatrixBase | str, target: Basic | MatrixBase | str
817
+ ) -> bool:
818
+ # If both are sympy expressions, we can use sympy to compare them
819
+ if isinstance(gold, (Basic, MatrixBase)) and isinstance(
820
+ target, (Basic, MatrixBase)
821
+ ):
822
+ return sympy_expr_eq(
823
+ gold, target, float_rounding, numeric_precision, allow_set_relation_comp, strict
824
+ )
825
+
826
+ # We don't support str / sympy.Expr comparison. Imo there is no point in doing this, as chances
827
+ # of this happening are very low. The only why one of them is not converted to sympy expression
828
+ # is usually because the parsing logic failed in this case we should improve the parsing logic
829
+ # instead of somehow fixing adhoc.
830
+ elif isinstance(gold, str) and isinstance(target, str):
831
+ # We just do string comparison for everything else
832
+ gold = gold.strip()
833
+ target = target.strip()
834
+
835
+ # Ensure it's both not empty and equal
836
+ return len(gold) > 0 and len(target) > 0 and gold == target
837
+
838
+ return False
839
+
840
+ def compare_single_extraction_wrapper(g, t):
841
+ try:
842
+ return compare_single_extraction(g, t)
843
+
844
+ except ValueError as e:
845
+ if str(e) == "signal only works in main thread of the main interpreter":
846
+ raise ValueError(
847
+ "Math-Verify doesn't support threaded environment due to usage of signal.alarm() in timeout mechanism. If you need to run in multithreaded environment it's recommended to set the parsing_timeout=None, which will run without timeout (and signal handling). In this case you need to handle the timeouting yourself."
848
+ ) from e
849
+ else:
850
+ if raise_on_error:
851
+ raise e from e
852
+ else:
853
+ logger.debug("Error during comparison", exc_info=True)
854
+ return False
855
+ except Exception as e:
856
+ #! Do not attempt to print out the g and t during handling of exception
857
+ # Because a) it can throw an exception itself and b) it can cause it to be stuck forever during str conversion
858
+ if raise_on_error:
859
+ raise e from e
860
+ else:
861
+ logger.debug("Error during comparison", exc_info=True)
862
+ return False
863
+ except TimeoutException as e:
864
+ if raise_on_error:
865
+ raise TimeoutException("Timeout during comparison") from e
866
+ else:
867
+ logger.warning("Timeout during comparison")
868
+ return False
869
+
870
+ if not isinstance(gold, list):
871
+ gold = [gold]
872
+ if not isinstance(target, list):
873
+ target = [target]
874
+
875
+ return any(
876
+ compare_single_extraction_wrapper(g, t) for g, t in product(gold, target)
877
+ )