import re
import math
from typing import Dict, Any, Tuple
class VerifierSystem:
    """
    Multi-stage verification system that returns graduated correctness scores
    instead of binary pass/fail. This provides a dense reward signal for RL
    training, enabling faster convergence.

    Correctness tiers:
        1.0  -> fully correct (exact or numerical match)
        0.7  -> structurally correct (right form, wrong coefficient)
        0.4  -> partially correct (correct technique identified)
        0.15 -> minimal credit (parseable math expression attempted)
        0.0  -> garbage / trivial output

    References:
        - DeepSeek-R1 GRPO reward design
        - arxiv:2408.10215 (Reward Engineering for RL)
        - arxiv:2601.19100 (Reward Engineering for Software Tasks)
    """

    # Integration techniques and their associated keywords.
    # NOTE(review): 'sin(θ)' etc. were mojibake ('sin(ΞΈ)') in the original
    # file; restored to the intended Greek theta.
    TECHNIQUE_KEYWORDS = {
        'u_substitution': ['substitut', 'u =', 'u=', 'let u', 'du'],
        'by_parts': ['by parts', 'integration by parts', 'ibp', 'uv -', 'udv'],
        'trig_sub': ['trig sub', 'trigonometric substitution', 'sin(θ)', 'cos(θ)', 'tan(θ)'],
        'partial_fraction': ['partial fraction', 'decompos'],
        'power_rule': ['power rule', 'x^n', 'x**'],
        'exponential': ['exponential', 'e^', 'exp('],
        'trigonometric': ['sin', 'cos', 'tan', 'sec', 'csc', 'cot'],
        'logarithmic': ['ln', 'log', 'logarithm'],
    }

    # Mathematical reasoning markers for process supervision.
    MATH_MARKERS = [
        'step', 'first', 'then', 'next', 'therefore', 'because', 'since',
        'equals', 'simplif', 'substitut', 'evaluat', 'factor', 'expand',
        'differentiat', 'integrat', 'apply', 'using', 'recall', 'note that',
        'we get', 'we have', 'we know', 'this gives', 'which yields',
    ]

    # NOTE(review): the original constant was UTF-8 mojibake; this is a
    # reconstruction of the intended math-symbol set — confirm upstream.
    MATH_SYMBOLS = set('∫∂∑∏±×÷≠≤≥←→↔∞√∈∉∩∪αβγδεζηθλμπρστφ')

    def __init__(self):
        # Stateless verifier; all configuration lives in class constants.
        pass

    def check_exact_match(self, prediction: str, ground_truth: str) -> bool:
        """1. Exact match verifier (case/whitespace insensitive)."""
        return prediction.strip().lower() == ground_truth.strip().lower()

    def check_numeric_tolerance(self, prediction: str, ground_truth: str, tol: float = 1e-4) -> bool:
        """2. Numeric tolerance checker.

        Both sides must parse as floats; compared with relative AND absolute
        tolerance `tol` so values near zero are handled sensibly.
        """
        try:
            pred_val = float(prediction.strip())
            gt_val = float(ground_truth.strip())
            return math.isclose(pred_val, gt_val, rel_tol=tol, abs_tol=tol)
        except ValueError:
            return False

    def check_python_execution(self, prediction: str, ground_truth: str) -> bool:
        """3. Python execution: evaluate the prediction as an expression.

        SECURITY NOTE: `eval` is used with builtins stripped and only `math`
        exposed; do not feed this untrusted input from outside the training
        loop without further sandboxing.
        """
        safe_dict = {"__builtins__": None, "math": math}
        try:
            # Verify that evaluating the prediction reproduces ground truth.
            pred_eval = eval(prediction.strip(), safe_dict, {})
            try:
                gt_eval = float(ground_truth.strip())
                return math.isclose(float(pred_eval), gt_eval, rel_tol=1e-4, abs_tol=1e-4)
            except ValueError:
                # Ground truth is not numeric — fall back to string comparison.
                return str(pred_eval).strip().lower() == ground_truth.strip().lower()
        except Exception:
            return False

    def check_numerical_integration(self, prediction: str, sympy_f: Any) -> bool:
        """
        [PAPER TRACEABILITY: Section 3.1.3 Solution Verification]
        Numerical multi-point quadrature verification.

        Differentiates the prediction F_pred(x) and compares it to the ground
        truth integrand f(x) at 5 random points.
        """
        import sympy as sp
        import random
        x = sp.Symbol('x')
        try:
            clean_pred = self._clean_math_answer(prediction)
            F_pred = sp.parse_expr(clean_pred)
            f_pred = sp.diff(F_pred, x)
            # Evaluate at 5 random points; any mismatch fails the check.
            for _ in range(5):
                test_point = random.uniform(-5, 5)
                p_val = float(f_pred.subs(x, test_point).evalf())
                t_val = float(sympy_f.subs(x, test_point).evalf())
                # Paper uses 10^-2 relative tolerance.
                if not math.isclose(p_val, t_val, rel_tol=1e-2, abs_tol=1e-2):
                    return False
            return True
        except Exception:
            return False

    def check_structural_similarity(self, prediction: str, ground_truth: str, sympy_f: Any = None) -> float:
        """
        Graduated structural similarity check.

        Compares SymPy expression trees to provide partial credit when the
        model's answer has the right structure but wrong coefficients.

        `sympy_f` is accepted for interface compatibility but unused here.

        Returns:
            0.7  if structure matches and values are proportional
            0.5  if structure matches but values are not proportional
            0.4  if the expression is parseable and shares operand types
            0.15 if the prediction is merely a parseable math expression
            0.0  if unparseable and not math-like
        """
        import sympy as sp
        x = sp.Symbol('x')
        try:
            clean_pred = self._clean_math_answer(prediction)
            clean_gt = self._clean_math_answer(ground_truth)
            pred_expr = sp.parse_expr(clean_pred)
            gt_expr = sp.parse_expr(clean_gt)
        except Exception:
            # Can't even parse — check if it at least looks like math.
            if self._looks_like_math(prediction):
                return 0.15
            return 0.0
        # Check if expression trees have similar structure.
        try:
            pred_funcs = self._extract_function_types(pred_expr)
            gt_funcs = self._extract_function_types(gt_expr)
            # Jaccard overlap of function types (sin, cos, exp, log, Pow, ...).
            overlap = pred_funcs & gt_funcs
            union = pred_funcs | gt_funcs
            if not union:
                return 0.15  # Both are just constants/variables.
            jaccard = len(overlap) / len(union)
            if jaccard >= 0.8:
                # Very similar structure — likely right form, wrong coefficient.
                # Verify by checking at sample points if shapes are proportional.
                if self._check_proportional(pred_expr, gt_expr, x):
                    return 0.7
                return 0.5
            elif jaccard >= 0.4:
                return 0.4
            else:
                return 0.15
        except Exception:
            return 0.15

    def check_technique_recognition(self, reasoning: str, technique_hint: str = "") -> float:
        """
        Checks if the model identified the correct integration technique.
        Returns a score in [0, 1] based on technique match.

        This provides reward signal even when the final answer is wrong,
        as long as the model is using the right approach.
        """
        if not technique_hint:
            return 0.0
        lower_r = reasoning.lower()
        # Count keyword hits for the hinted technique.
        technique_kws = self.TECHNIQUE_KEYWORDS.get(technique_hint, [])
        if not technique_kws:
            return 0.0
        matches = sum(1 for kw in technique_kws if kw in lower_r)
        if matches >= 2:
            return 1.0  # Strong evidence of correct technique.
        elif matches == 1:
            return 0.6  # Some evidence.
        # Partial credit if ANY known technique was attempted at all.
        any_technique = any(
            any(kw in lower_r for kw in kws)
            for kws in self.TECHNIQUE_KEYWORDS.values()
        )
        return 0.2 if any_technique else 0.0

    def mock_llm_judge(self, reasoning: str, prediction: str, ground_truth: str) -> float:
        """4. LLM judge (mock/placeholder scoring reasoning quality).

        Returns reasoning quality score Q in [0.0, 1.0]. `prediction` and
        `ground_truth` are accepted for interface compatibility but unused.
        """
        score = 0.0
        lower_reasoning = reasoning.lower()
        words = reasoning.split()
        length = len(words)
        # Length bonus (up to 0.25) — diminishing returns, gentle curve.
        score += min(0.25, length * 0.005)
        # Mathematical marker bonus (up to 0.35).
        marker_count = sum(1 for m in self.MATH_MARKERS if m in lower_reasoning)
        score += min(0.35, marker_count * 0.05)
        # Mathematical symbol density bonus (up to 0.2).
        # NOTE(review): the original symbol string was mojibake; restored to
        # the likely intended '∫√∂π' — confirm upstream.
        math_chars = sum(1 for c in reasoning if c in '=+-*/^()∫√∂π' or c in self.MATH_SYMBOLS)
        if length > 0:
            math_density = math_chars / max(1, len(reasoning))
            score += min(0.2, math_density * 2.0)
        # Structured step progression bonus (up to 0.2).
        has_numbered_steps = bool(re.search(r'step\s*\d|^\d+[\.\)]', lower_reasoning, re.MULTILINE))
        has_logical_flow = ('therefore' in lower_reasoning or 'thus' in lower_reasoning or
                            'hence' in lower_reasoning or 'so we' in lower_reasoning)
        if has_numbered_steps:
            score += 0.12
        if has_logical_flow:
            score += 0.08
        return round(min(1.0, score), 3)

    def check_process_supervision(self, reasoning: str) -> float:
        """
        [PAPER TRACEABILITY: Process Supervision (Lightweight PRM)]
        E. PROCESS SUPERVISION (STEP-AWARE REWARD)

        Scores stepwise structure, math-operation density, technique mentions,
        and intermediate results; penalizes short reasoning that jumps to
        conclusions. Returns a value clamped to [-1, 1].
        """
        lower_r = reasoning.lower()
        words = lower_r.split()
        word_count = len(words)
        score = 0.0
        # 1. Stepwise structure (up to 0.4).
        numbered_steps = len(re.findall(r'step\s*\d', lower_r))
        if numbered_steps >= 3:
            score += 0.4
        elif numbered_steps >= 2:
            score += 0.3
        elif numbered_steps >= 1:
            score += 0.2
        elif 'first' in lower_r and ('then' in lower_r or 'next' in lower_r):
            score += 0.15
        # 2. Mathematical operation density (up to 0.3).
        math_ops = len(re.findall(r'[=+\-*/^]', reasoning))
        if word_count > 0:
            op_density = math_ops / word_count
            score += min(0.3, op_density * 3.0)
        # 3. Technique identification bonus (up to 0.2).
        techniques_mentioned = sum(
            1 for kws in self.TECHNIQUE_KEYWORDS.values()
            if any(kw in lower_r for kw in kws)
        )
        score += min(0.2, techniques_mentioned * 0.1)
        # 4. Logical jump penalty — short reasoning with complex claims.
        # BUGFIX: the substring test `'so' in lower_r` falsely fired on words
        # like "solution"/"reason"; match 'so' as a whole word instead.
        if word_count < 10 and ('=' in lower_r or re.search(r'\bso\b', lower_r)):
            score -= 0.3
        elif word_count < 20 and math_ops > 3:
            score -= 0.15  # Slightly suspicious — many operations, few words.
        # 5. Bonus for showing intermediate results (e.g. "= 4x").
        intermediate_results = len(re.findall(r'=\s*[\d\w]', reasoning))
        score += min(0.1, intermediate_results * 0.02)
        return max(-1.0, min(1.0, score))

    def check_reflection(self, reasoning: str, c: float) -> float:
        """
        [PAPER TRACEABILITY: Reflection Module]
        H. REFLECTION MODULE

        Model generates "What could be wrong?".
        Penalize if contradiction with final answer, reward correct
        self-correction. Graduated by reflection quality; `c` is the
        correctness score already assigned to the answer.
        """
        lower_r = reasoning.lower()
        score = 0.0
        reflection_phrases = [
            "what could be wrong", "wait,", "let me check", "alternatively",
            "let me verify", "double check", "reconsider", "hmm",
            "actually,", "correction:", "i made an error", "let me redo"
        ]
        reflections_found = sum(1 for phrase in reflection_phrases if phrase in lower_r)
        if reflections_found > 0:
            if c >= 0.7:  # At least partially correct.
                # Graduated reward based on how many reflection markers used.
                score += min(1.0, 0.5 + reflections_found * 0.2)
            elif c >= 0.4:
                # Some credit — reflected but didn't fully fix.
                score += 0.1
            else:
                # Reflected but still wrong — mild penalty.
                score -= 0.3
        return max(-1.0, min(1.0, score))

    def verify(self, reasoning: str, prediction: str, ground_truth: str,
               sympy_f: Any = None, technique_hint: str = "") -> Tuple[float, float, float, float]:
        """
        Run all verifiers with GRADUATED CORRECTNESS scoring.

        Returns:
            C -> Correctness in [0, 1] (graduated, not binary)
            Q -> Reasoning Quality in [0, 1]
            P -> Process Supervision in [-1, 1]
            R -> Reflection Score in [-1, 1]
        """
        # --- Graduated Correctness ---
        c = 0.0
        # Tier 1: full correctness (1.0) via any strong verifier.
        if self.check_exact_match(prediction, ground_truth):
            c = 1.0
        elif sympy_f is not None and self.check_numerical_integration(prediction, sympy_f):
            c = 1.0
        elif self.check_numeric_tolerance(prediction, ground_truth):
            c = 1.0
        elif self.check_python_execution(prediction, ground_truth):
            c = 1.0
        # Tiers 2-4: partial credit (only if not fully correct).
        if c < 1.0:
            structural_score = self.check_structural_similarity(prediction, ground_truth, sympy_f)
            technique_score = self.check_technique_recognition(reasoning, technique_hint)
            # Take the best partial credit signal.
            c = max(c, structural_score)
            # Technique recognition can boost partial credit.
            if technique_score > 0 and c < 0.7:
                c = max(c, 0.15 + technique_score * 0.25)  # Up to 0.4 from technique alone.
        q = self.mock_llm_judge(reasoning, prediction, ground_truth)
        p = self.check_process_supervision(reasoning)
        r = self.check_reflection(reasoning, c)
        return c, q, p, r

    # --- Private Helpers ---
    def _clean_math_answer(self, text: str) -> str:
        """Clean a math answer string for SymPy parsing.

        BUGFIX: LaTeX wrappers ('$', backslashes) are now stripped BEFORE the
        constant of integration — the original order left '+ C' in place for
        answers like 'x**2 + C$' because the anchored regex never matched.
        Also converts caret exponents to '**', since SymPy's default
        `parse_expr` treats '^' as XOR.
        """
        clean = text.strip()
        if "Answer:" in clean:
            clean = clean.split("Answer:")[-1].strip()
        # Remove LaTeX wrappers first so the end-of-string anchor below works.
        clean = clean.replace('$', '').replace('\\', '').strip()
        # Caret exponent -> Python/SymPy power operator.
        clean = clean.replace('^', '**')
        # Remove constant of integration.
        clean = re.sub(r'\+\s*[Cc]\s*$', '', clean).strip()
        return clean

    def _looks_like_math(self, text: str) -> bool:
        """Heuristic: text contains at least two math indicators."""
        math_indicators = ['=', '+', '-', '*', '/', '^', 'x', 'sin', 'cos', 'exp', 'log', '(']
        return sum(1 for m in math_indicators if m in text.lower()) >= 2

    def _extract_function_types(self, expr) -> set:
        """Extract the set of function/operator types from a SymPy expression tree."""
        import sympy as sp
        types = set()
        if isinstance(expr, sp.Add):
            types.add('Add')
        elif isinstance(expr, sp.Mul):
            types.add('Mul')
        elif isinstance(expr, sp.Pow):
            types.add('Pow')
        # SymPy function classes are named after the function (sin, cos, ...).
        func_type = type(expr).__name__
        if func_type in ('sin', 'cos', 'tan', 'exp', 'log', 'ln', 'Abs',
                         'asin', 'acos', 'atan', 'sinh', 'cosh', 'tanh'):
            types.add(func_type)
        # Recurse into sub-expressions.
        if hasattr(expr, 'args'):
            for arg in expr.args:
                types |= self._extract_function_types(arg)
        return types

    def _check_proportional(self, expr1, expr2, x) -> bool:
        """Check if two expressions differ only by a constant factor.

        Samples 3 random points and tests whether the value ratios agree
        within 10% relative tolerance; near-zero denominators are skipped.
        """
        import sympy as sp
        import random
        try:
            ratios = []
            for _ in range(3):
                pt = random.uniform(-3, 3)
                v1 = float(expr1.subs(x, pt).evalf())
                v2 = float(expr2.subs(x, pt).evalf())
                if abs(v2) < 1e-10:
                    continue
                ratios.append(v1 / v2)
            if len(ratios) < 2:
                return False
            # All ratios approximately equal => constant factor.
            return all(math.isclose(r, ratios[0], rel_tol=0.1) for r in ratios)
        except Exception:
            return False
|