mvm2-math-verification / verification_service.py
Varshithdharmajv's picture
Upload verification_service.py with huggingface_hub
11f174e verified
import re
from typing import List, Dict, Any
from sympy import sympify, simplify, Eq, parse_expr
try:
from math_verify import parse, verify
MATH_VERIFY_AVAILABLE = True
except ImportError:
MATH_VERIFY_AVAILABLE = False
# Building blocks for the plain-text equation pattern.
# An operand is a run of letters, digits and parentheses.  A '.' is accepted
# only when a digit follows, so decimals such as "1.5" stay intact while a
# sentence-ending period is left out of the match.
_EQ_OPERAND = r'(?:[A-Za-z0-9()]|\.(?=\d))+'
# '^' is accepted as a power operator alongside '**'.
_EQ_OPERATOR = r'(?:\*\*|[+\-*/^])'
# Operands joined by operators, with optional spaces around each operator and
# an optional sign in front of each operand.
_EQ_EXPRESSION = (
    r'[+\-]?' + _EQ_OPERAND
    + r'(?: *' + _EQ_OPERATOR + r' *[+\-]?' + _EQ_OPERAND + r')*'
)
_EQUATION_PATTERNS = (
    re.compile(r'(\$.*?\$)'),        # inline LaTeX: $ ... $
    re.compile(r'(\\\[.*?\\\])'),    # display LaTeX: \[ ... \]
    re.compile('(' + _EQ_EXPRESSION + ' *= *' + _EQ_EXPRESSION + ')'),
)


def extract_equations(text: str) -> List[str]:
    """Extract mathematical equations from a reasoning step.

    Three sources are scanned in order: inline LaTeX (``$...$``), display
    LaTeX (``\\[...\\]``) and plain-text equations such as ``2 + 3 = 5`` or
    ``x^2 = 4``.  LaTeX delimiters are removed from the results and only
    candidates that contain ``=`` are kept.

    If none of the patterns yields an equation, every line that contains
    ``=`` and whose characters are less than half alphabetic is returned
    instead (stripped).

    Args:
        text: The raw text of one reasoning step.

    Returns:
        The extracted equations in order of first appearance, without
        duplicates.  Empty when nothing equation-like is found.
    """
    found = []
    for pattern in _EQUATION_PATTERNS:
        for match in pattern.findall(text):
            clean = match.replace('$', '').replace('\\[', '').replace('\\]', '').strip()
            if '=' in clean:
                found.append(clean)

    if not found:
        # Fallback heuristic: treat "mostly non-alphabetic" lines with an
        # equals sign as equations.
        found = [
            line.strip()
            for line in text.split('\n')
            if '=' in line and sum(c.isalpha() for c in line) < len(line) / 2
        ]

    # An equation inside $...$ is also seen by the plain-text pattern; keep
    # only the first occurrence so callers do not compare it repeatedly.
    return list(dict.fromkeys(found))
# "<number> = <number>" where neither number is part of a longer token.  The
# lookarounds reject digits glued to identifiers ("x2 = 3", "3 = 4x"), halves
# of a decimal ("1.5 = 1.5") and thousands separators ("1,000 = 1000").
_BARE_NUMERIC_EQUALITY = re.compile(
    r'(?<![\w.,])(\d+(?:\.\d+)?)\s*=\s*(\d+(?:\.\d+)?)(?![\w(]|[.,]\d)'
)
# An arithmetic operator directly before / after the matched equality means
# the numbers are operands of a larger expression, not standalone values.
_OPERATOR_BEFORE = re.compile(
    r'(?:[+\-*/^\u00d7\u00f7\u00b7\u2212]|\\(?:times|cdot|div|pm|mp)|\d\s*[xX])\s*$'
)
_OPERATOR_AFTER = re.compile(
    r'\s*(?:[+\-*/^%\u00d7\u00f7\u00b7\u2212]'
    r'|\\(?:times|cdot|div|pm|mp)(?![A-Za-z])|[xX]\s*\d)'
)


def _has_numeric_contradiction(text: str) -> bool:
    """Return True if *text* equates two different standalone numbers."""
    for match in _BARE_NUMERIC_EQUALITY.finditer(text):
        # Skip operands of a larger expression, e.g. the "3 = 5" inside
        # "2 + 3 = 5" or the "5 = 4" inside "5 = 4 + 1".
        if _OPERATOR_BEFORE.search(text, 0, match.start()):
            continue
        if _OPERATOR_AFTER.match(text, match.end()):
            continue
        try:
            # Compare by value so "10 = 100" differs and "5 = 5.0" does not.
            left, right = (
                float(number) if '.' in number else int(number)
                for number in match.groups()
            )
        except ValueError:
            # Absurdly long digit strings can exceed int() limits; ignore.
            continue
        if left != right:
            return True
    return False


def _equations_equivalent(eq1: str, eq2: str) -> bool:
    """Best-effort test whether two ``lhs = rhs`` strings state the same thing."""
    if MATH_VERIFY_AVAILABLE:
        try:
            return bool(verify(parse(eq1), parse(eq2)))
        except Exception:
            # math_verify could not handle the input; fall back to text equality.
            return eq1.strip() == eq2.strip()

    try:
        lhs1, rhs1 = eq1.split('=', 1)
        lhs2, rhs2 = eq2.split('=', 1)
        # NOTE(review): SymPy documents that its string parsing relies on
        # eval(); the steps presumably come from model output, so confirm
        # this runs in a sandboxed / trusted context.
        expr1 = parse_expr(lhs1.replace('^', '**')) - parse_expr(rhs1.replace('^', '**'))
        expr2 = parse_expr(lhs2.replace('^', '**')) - parse_expr(rhs2.replace('^', '**'))
        # Same residual, or the same residual with the sides swapped.
        return bool(simplify(expr1) == simplify(expr2) or simplify(expr1 + expr2) == 0)
    except Exception:
        # The parser raises many unrelated exception types on free-form text.
        # An unparseable pair is simply "not shown equivalent"; it must not
        # stop the remaining pairs from being compared.
        return False


def check_logical_progression(step_n: str, step_n_plus_1: str) -> bool:
    r"""
    Implements the SymPy Validation function \vartheta(r_{jl}).

    The check is lenient: a transition is rejected only when step n+1
    states that two different standalone numbers are equal (e.g. ``3 = 4``).
    Equation pairs that cannot be parsed or shown equivalent do not make
    the transition invalid.

    Args:
        step_n: Text of the earlier reasoning step.  May be empty, in which
            case only the self-consistency of ``step_n_plus_1`` is checked.
        step_n_plus_1: Text of the step that follows ``step_n``.

    Returns:
        ``False`` if a numeric contradiction is found in ``step_n_plus_1``
        before any equation pair is shown equivalent, ``True`` otherwise.
    """
    eqs_n = extract_equations(step_n)
    eqs_n_plus_1 = extract_equations(step_n_plus_1)

    # Judge contradictions on the full step text: extracted fragments can
    # lose the operators around them and look contradictory when they are not.
    contradiction_in_step = _has_numeric_contradiction(step_n_plus_1)

    if eqs_n and eqs_n_plus_1:
        for eq1 in eqs_n:
            for eq2 in eqs_n_plus_1:
                if contradiction_in_step and _has_numeric_contradiction(eq2):
                    return False
                if '=' in eq1 and '=' in eq2 and _equations_equivalent(eq1, eq2):
                    return True

    # Also reached when either step has no equations, so that a call with an
    # empty previous step still reports a contradictory step.
    return not contradiction_in_step
def calculate_symbolic_score(reasoning_trace: List[str]) -> float:
    """
    Calculates V^{sym}_j based on the logical sequence of steps.

    The score is the share of consecutive step pairs accepted by
    ``check_logical_progression``.  It is halved once if
    ``check_logical_progression("", step)`` rejects any individual step,
    and the result is rounded to two decimals.

    Args:
        reasoning_trace: The reasoning steps in order.

    Returns:
        ``0.0`` for an empty trace, ``1.0`` for a single step, otherwise
        the rounded score described above.
    """
    if not reasoning_trace:
        return 0.0
    if len(reasoning_trace) <= 1:
        # A lone step has no transition to judge.
        return 1.0

    transition_count = len(reasoning_trace) - 1
    accepted = sum(
        1
        for earlier, later in zip(reasoning_trace, reasoning_trace[1:])
        if check_logical_progression(earlier, later)
    )
    score = float(accepted) / float(transition_count)

    # One penalty at most, no matter how many steps are rejected on their own.
    if any(not check_logical_progression("", step) for step in reasoning_trace):
        score *= 0.5

    return round(score, 2)