# AutoMathReasoner/env/verifier.py
import re
import math
from typing import Dict, Any, Tuple
class VerifierSystem:
"""
Multi-stage verification system that returns graduated correctness scores
instead of binary pass/fail. This provides a dense reward signal for RL
training, enabling faster convergence.
Correctness tiers:
    1.0  - Fully correct (exact or numerical match)
    0.7  - Structurally correct (right form, wrong coefficient)
    0.4  - Partially correct (correct technique identified)
    0.15 - Minimal credit (parseable math expression attempted)
    0.0  - Garbage / trivial output
References:
- DeepSeek-R1 GRPO reward design
- arxiv:2408.10215 (Reward Engineering for RL)
- arxiv:2601.19100 (Reward Engineering for Software Tasks)
"""
# Integration techniques and their associated keywords
TECHNIQUE_KEYWORDS = {
'u_substitution': ['substitut', 'u =', 'u=', 'let u', 'du'],
'by_parts': ['by parts', 'integration by parts', 'ibp', 'uv -', 'udv'],
        'trig_sub': ['trig sub', 'trigonometric substitution', 'sin(θ)', 'cos(θ)', 'tan(θ)'],
'partial_fraction': ['partial fraction', 'decompos'],
'power_rule': ['power rule', 'x^n', 'x**'],
'exponential': ['exponential', 'e^', 'exp('],
'trigonometric': ['sin', 'cos', 'tan', 'sec', 'csc', 'cot'],
'logarithmic': ['ln', 'log', 'logarithm'],
}
# Mathematical reasoning markers for process supervision
MATH_MARKERS = [
'step', 'first', 'then', 'next', 'therefore', 'because', 'since',
'equals', 'simplif', 'substitut', 'evaluat', 'factor', 'expand',
'differentiat', 'integrat', 'apply', 'using', 'recall', 'note that',
'we get', 'we have', 'we know', 'this gives', 'which yields',
]
    MATH_SYMBOLS = set('∫∂∑∏√±×÷≠≤≥≈∞∝∈∉⊂⊃∩∪αβγδεζηθλμπσφψω')
def __init__(self):
pass
def check_exact_match(self, prediction: str, ground_truth: str) -> bool:
"""1. Exact match verifier"""
return prediction.strip().lower() == ground_truth.strip().lower()
def check_numeric_tolerance(self, prediction: str, ground_truth: str, tol: float = 1e-4) -> bool:
"""2. Numeric tolerance checker"""
try:
pred_val = float(prediction.strip())
gt_val = float(ground_truth.strip())
return math.isclose(pred_val, gt_val, rel_tol=tol, abs_tol=tol)
except ValueError:
return False
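    # Illustrative examples (not executed), assuming the default tol=1e-4:
    #   check_numeric_tolerance("3.14159", "3.1416")  -> True   (within tolerance)
    #   check_numeric_tolerance("2.0", "2.1")         -> False
    #   check_numeric_tolerance("x**2/2", "0.5")      -> False  (not a float literal)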
def check_python_execution(self, prediction: str, ground_truth: str) -> bool:
"""3. Python execution (eval safe expressions)"""
# If prediction is an expression like "2+3", try evaluating it safely
safe_dict = {"__builtins__": None, "math": math}
try:
# We are verifying if evaluating the prediction gives ground truth
pred_eval = eval(prediction.strip(), safe_dict, {})
try:
gt_eval = float(ground_truth.strip())
return math.isclose(float(pred_eval), gt_eval, rel_tol=1e-4, abs_tol=1e-4)
except ValueError:
return str(pred_eval).strip().lower() == ground_truth.strip().lower()
except Exception:
return False
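    # Illustrative examples (not executed): the prediction is evaluated with no
    # builtins and only the math module in scope.
    #   check_python_execution("2 + 3", "5")               -> True
    #   check_python_execution("math.sqrt(2)", "1.41421")  -> True   (1e-4 tolerance)
    #   check_python_execution("import os", "5")           -> False  (statements fail in eval)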
def check_numerical_integration(self, prediction: str, sympy_f: Any) -> bool:
"""
[PAPER TRACEABILITY: Section 3.1.3 Solution Verification]
Numerical multi-point quadrature verification.
Differentiates the prediction F_pred(x) and compares it to the ground
truth integrand f(x) at 5 random points.
"""
import sympy as sp
import random
x = sp.Symbol('x')
try:
clean_pred = self._clean_math_answer(prediction)
F_pred = sp.parse_expr(clean_pred)
f_pred = sp.diff(F_pred, x)
# Evaluate at 5 random points
for _ in range(5):
test_point = random.uniform(-5, 5)
p_val = float(f_pred.subs(x, test_point).evalf())
t_val = float(sympy_f.subs(x, test_point).evalf())
# Paper uses 10^-2 relative tolerance
if not math.isclose(p_val, t_val, rel_tol=1e-2, abs_tol=1e-2):
return False
return True
except Exception:
return False
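    # Illustrative example (not executed), assuming sympy_f is the SymPy integrand x:
    # the prediction "x**2/2 + C" is cleaned to "x**2/2", differentiated to x, and
    # agrees with sympy_f at the sampled points -> True; a prediction of "x**2"
    # differentiates to 2*x and (almost surely) fails the tolerance -> False.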
def check_structural_similarity(self, prediction: str, ground_truth: str, sympy_f: Any = None) -> float:
"""
Graduated structural similarity check.
Compares SymPy expression trees to provide partial credit when the
model's answer has the right structure but wrong coefficients.
        Returns:
            0.7  if the structure matches and the expressions are proportional
            0.5  if the structure matches but the expressions are not proportional
            0.4  if the expression is parseable and shares some operand types
            0.15 if the prediction is at least a parseable math expression
            0.0  if unparseable and it does not even look like math
        Note: sympy_f is accepted for signature compatibility but is not used here.
"""
import sympy as sp
x = sp.Symbol('x')
try:
clean_pred = self._clean_math_answer(prediction)
clean_gt = self._clean_math_answer(ground_truth)
pred_expr = sp.parse_expr(clean_pred)
gt_expr = sp.parse_expr(clean_gt)
except Exception:
            # Can't even parse; check whether it at least looks like math
if self._looks_like_math(prediction):
return 0.15
return 0.0
# Check if expression trees have similar structure
try:
pred_funcs = self._extract_function_types(pred_expr)
gt_funcs = self._extract_function_types(gt_expr)
# Count overlapping function types (sin, cos, exp, log, Pow, etc.)
overlap = pred_funcs & gt_funcs
union = pred_funcs | gt_funcs
if not union:
return 0.15 # Both are just constants/variables
jaccard = len(overlap) / len(union)
if jaccard >= 0.8:
                # Very similar structure: likely the right form with a wrong coefficient
# Verify by checking at sample points if shapes are proportional
if self._check_proportional(pred_expr, gt_expr, x):
return 0.7
return 0.5
elif jaccard >= 0.4:
return 0.4
else:
return 0.15
except Exception:
return 0.15
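    # Illustrative examples (not executed):
    #   "2*x**3"  vs "5*x**3"  -> 0.7   (same {Mul, Pow} structure and proportional)
    #   "x**2"    vs "x**2/2"  -> 0.4   (the 1/2 adds a Mul node, so Jaccard is 0.5)
    #   "sin(x)"  vs "x**2"    -> 0.15  (parseable, but no shared function types)
    #   "no idea" vs "x**2"    -> 0.0   (unparseable and does not look like math)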
def check_technique_recognition(self, reasoning: str, technique_hint: str = "") -> float:
"""
Checks if the model identified the correct integration technique.
        Returns a score ∈ [0, 1] based on technique match.
This provides reward signal even when the final answer is wrong,
as long as the model is using the right approach.
"""
if not technique_hint:
return 0.0
lower_r = reasoning.lower()
# Check if the correct technique keywords appear in reasoning
technique_kws = self.TECHNIQUE_KEYWORDS.get(technique_hint, [])
if not technique_kws:
return 0.0
matches = sum(1 for kw in technique_kws if kw in lower_r)
if matches >= 2:
return 1.0 # Strong evidence of correct technique
elif matches == 1:
return 0.6 # Some evidence
# Check if any technique was attempted at all
any_technique = False
for tech, kws in self.TECHNIQUE_KEYWORDS.items():
if any(kw in lower_r for kw in kws):
any_technique = True
break
return 0.2 if any_technique else 0.0
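    # Illustrative example (not executed): with technique_hint='u_substitution',
    # the reasoning "Let u = x**2, so du = 2*x dx" hits the keywords 'let u', 'u ='
    # and 'du' (>= 2 matches) -> 1.0, while reasoning that only mentions "by parts"
    # scores 0.2 (some technique attempted, but not the hinted one).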
def mock_llm_judge(self, reasoning: str, prediction: str, ground_truth: str) -> float:
"""4. LLM judge (mock or placeholder scoring reasoning quality)
Returns reasoning quality score Q (0.0 to 1.0)
Improved with mathematical density scoring and better structural analysis.
"""
score = 0.0
lower_reasoning = reasoning.lower()
words = reasoning.split()
length = len(words)
        # Length bonus (up to 0.25): diminishing returns, gentle curve
score += min(0.25, length * 0.005)
# Mathematical marker bonus (up to 0.35)
marker_count = sum(1 for m in self.MATH_MARKERS if m in lower_reasoning)
score += min(0.35, marker_count * 0.05)
# Mathematical symbol density bonus (up to 0.2)
        math_chars = sum(1 for c in reasoning if c in '=+-*/^()∫∂∑√' or c in self.MATH_SYMBOLS)
if length > 0:
math_density = math_chars / max(1, len(reasoning))
score += min(0.2, math_density * 2.0)
# Structured step progression bonus (up to 0.2)
has_numbered_steps = bool(re.search(r'step\s*\d|^\d+[\.\)]', lower_reasoning, re.MULTILINE))
has_logical_flow = ('therefore' in lower_reasoning or 'thus' in lower_reasoning or
'hence' in lower_reasoning or 'so we' in lower_reasoning)
if has_numbered_steps:
score += 0.12
if has_logical_flow:
score += 0.08
return round(min(1.0, score), 3)
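    # Rough illustration (not executed): the components cap at 0.25 (length),
    # 0.35 (reasoning markers), 0.2 (symbol density) and 0.2 (step structure), so a
    # well-structured multi-step derivation can approach Q = 1.0, while a bare final
    # answer with no prose stays low.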
def check_process_supervision(self, reasoning: str) -> float:
"""
[PAPER TRACEABILITY: Process Supervision (Lightweight PRM)]
E. PROCESS SUPERVISION (STEP-AWARE REWARD)
Improved with:
- Mathematical density scoring
- Multi-level step detection
- Granular logical jump penalties
- Technique-specific reward signals
"""
lower_r = reasoning.lower()
words = lower_r.split()
word_count = len(words)
score = 0.0
# 1. Check stepwise structure (up to 0.4)
numbered_steps = len(re.findall(r'step\s*\d', lower_r))
if numbered_steps >= 3:
score += 0.4
elif numbered_steps >= 2:
score += 0.3
elif numbered_steps >= 1:
score += 0.2
elif 'first' in lower_r and ('then' in lower_r or 'next' in lower_r):
score += 0.15
# 2. Mathematical operation density (up to 0.3)
math_ops = len(re.findall(r'[=+\-*/^]', reasoning))
if word_count > 0:
op_density = math_ops / word_count
score += min(0.3, op_density * 3.0)
# 3. Technique identification bonus (up to 0.2)
techniques_mentioned = 0
for tech, kws in self.TECHNIQUE_KEYWORDS.items():
if any(kw in lower_r for kw in kws):
techniques_mentioned += 1
score += min(0.2, techniques_mentioned * 0.1)
# 4. Logical jump penalty β€” short reasoning with complex claims
if word_count < 10 and ('=' in lower_r or 'so' in lower_r):
score -= 0.3
elif word_count < 20 and math_ops > 3:
            score -= 0.15  # Slightly suspicious: many operations, few words
# 5. Bonus for showing intermediate results
intermediate_results = len(re.findall(r'=\s*[\d\w]', reasoning))
score += min(0.1, intermediate_results * 0.02)
return max(-1.0, min(1.0, score))
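    # Illustrative behaviour (not executed): three or more "Step n" markers alone
    # contribute +0.4, with operator density and technique mentions adding up to
    # another +0.5, while a bare "so the answer = 5" (under 10 words containing '=')
    # incurs the -0.3 logical-jump penalty.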
def check_reflection(self, reasoning: str, c: float) -> float:
"""
[PAPER TRACEABILITY: Reflection Module]
H. REFLECTION MODULE
Model generates "What could be wrong?"
Penalize if contradiction with final answer, reward correct self-correction.
Improved with graduated scoring based on reflection quality.
"""
lower_r = reasoning.lower()
score = 0.0
reflection_phrases = [
"what could be wrong", "wait,", "let me check", "alternatively",
"let me verify", "double check", "reconsider", "hmm",
"actually,", "correction:", "i made an error", "let me redo"
]
reflections_found = sum(1 for phrase in reflection_phrases if phrase in lower_r)
if reflections_found > 0:
if c >= 0.7: # At least partially correct
# Graduated reward based on how many reflection markers used
score += min(1.0, 0.5 + reflections_found * 0.2)
elif c >= 0.4:
                # Some credit: reflected but didn't fully fix the answer
score += 0.1
else:
                # Reflected but still wrong: mild penalty (not as harsh as before)
score -= 0.3
return max(-1.0, min(1.0, score))
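    # Illustrative behaviour (not executed): reflection markers such as "let me verify"
    # or "wait," earn 0.5 + 0.2 per marker (capped at 1.0) when c >= 0.7, a token 0.1
    # when c is in [0.4, 0.7), and a -0.3 penalty when the answer is still wrong.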
def verify(self, reasoning: str, prediction: str, ground_truth: str,
sympy_f: Any = None, technique_hint: str = "") -> Tuple[float, float, float, float]:
"""
Run all verifiers with GRADUATED CORRECTNESS scoring.
Returns:
            C - Correctness ∈ [0, 1] (graduated, not binary)
            Q - Reasoning Quality ∈ [0, 1]
            P - Process Supervision ∈ [-1, 1]
            R - Reflection Score ∈ [-1, 1]
"""
# --- Graduated Correctness ---
c = 0.0
# Tier 1: Full correctness (1.0)
if self.check_exact_match(prediction, ground_truth):
c = 1.0
elif sympy_f is not None and self.check_numerical_integration(prediction, sympy_f):
c = 1.0
elif self.check_numeric_tolerance(prediction, ground_truth):
c = 1.0
elif self.check_python_execution(prediction, ground_truth):
c = 1.0
# Tier 2-4: Partial credit (only if not fully correct)
if c < 1.0:
structural_score = self.check_structural_similarity(prediction, ground_truth, sympy_f)
technique_score = self.check_technique_recognition(reasoning, technique_hint)
# Take the best partial credit signal
c = max(c, structural_score)
# Technique recognition can boost partial credit
if technique_score > 0 and c < 0.7:
c = max(c, 0.15 + technique_score * 0.25) # Up to 0.4 from technique alone
q = self.mock_llm_judge(reasoning, prediction, ground_truth)
p = self.check_process_supervision(reasoning)
r = self.check_reflection(reasoning, c)
return c, q, p, r
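    # Illustrative flow (not executed): an exact or numerically verified answer
    # short-circuits to C = 1.0; otherwise the best of the structural-similarity and
    # technique-recognition signals is used, so a right-form / wrong-coefficient
    # answer typically lands near C = 0.7 and a correct-technique-only attempt at
    # up to C = 0.4.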
# --- Private Helpers ---
def _clean_math_answer(self, text: str) -> str:
"""Clean a math answer string for SymPy parsing."""
clean = text.strip()
if "Answer:" in clean:
clean = clean.split("Answer:")[-1].strip()
        # Remove LaTeX wrappers first so a trailing "+ C" inside $...$ is still caught.
        # Only the '$' and '\' characters are stripped; LaTeX commands such as \frac
        # are not translated into SymPy syntax.
        clean = clean.replace('$', '').replace('\\', '')
        # Remove the constant of integration
        clean = re.sub(r'\+\s*[Cc]\s*$', '', clean).strip()
return clean
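    # Illustrative example (not executed):
    #   _clean_math_answer("Answer: x**2/2 + C")  ->  "x**2/2"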
def _looks_like_math(self, text: str) -> bool:
"""Check if text contains mathematical content."""
math_indicators = ['=', '+', '-', '*', '/', '^', 'x', 'sin', 'cos', 'exp', 'log', '(']
return sum(1 for m in math_indicators if m in text.lower()) >= 2
def _extract_function_types(self, expr) -> set:
"""Extract the set of function types from a SymPy expression tree."""
import sympy as sp
types = set()
if isinstance(expr, sp.Add):
types.add('Add')
elif isinstance(expr, sp.Mul):
types.add('Mul')
elif isinstance(expr, sp.Pow):
types.add('Pow')
func_type = type(expr).__name__
if func_type in ('sin', 'cos', 'tan', 'exp', 'log', 'ln', 'Abs',
'asin', 'acos', 'atan', 'sinh', 'cosh', 'tanh'):
types.add(func_type)
# Recurse into sub-expressions
if hasattr(expr, 'args'):
for arg in expr.args:
types |= self._extract_function_types(arg)
return types
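    # Illustrative examples (not executed):
    #   sin(x)*exp(x)  -> {'Mul', 'sin', 'exp'}
    #   3*x**2 + 1     -> {'Add', 'Mul', 'Pow'}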
def _check_proportional(self, expr1, expr2, x) -> bool:
"""Check if two expressions are proportional (differ only by a constant factor)."""
import sympy as sp
import random
try:
ratios = []
for _ in range(3):
pt = random.uniform(-3, 3)
v1 = float(expr1.subs(x, pt).evalf())
v2 = float(expr2.subs(x, pt).evalf())
if abs(v2) < 1e-10:
continue
ratios.append(v1 / v2)
if len(ratios) < 2:
return False
# Check if all ratios are approximately equal (constant factor)
return all(math.isclose(r, ratios[0], rel_tol=0.1) for r in ratios)
except Exception:
return False
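

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the training pipeline).
    # The second call relies on SymPy being installed, because the partial-credit
    # path in check_structural_similarity parses both answers with sympy.parse_expr.
    verifier = VerifierSystem()
    reasoning = ("Step 1: apply the power rule. Step 2: integrating x gives x**2/2. "
                 "Therefore the antiderivative is x**2/2 + C.")

    # Exact string match: expect C = 1.0.
    print(verifier.verify(reasoning, "x**2/2 + C", "x**2/2 + C"))

    # Right form, wrong coefficient: expect C around 0.7 via structural similarity.
    print(verifier.verify(reasoning, "2*x**3", "5*x**3"))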