"""
Validation module for the Math Expert model
"""

import datetime
import hashlib
import json
import os
from pathlib import Path
from typing import Dict, Any, List, Optional

import numpy as np
from sympy import simplify, Eq


class MathValidator:
    """Validate equations, proofs and datasets for the Math Expert model.

    Every validation result is a plain dictionary so it can be written
    directly to JSON.  Checkpoints and reports are stored as JSON files in
    ``<checkpoint_dir>/validation``.
    """

    # Factor applied to a proof's score for every step shown to be false.
    _STEP_PENALTY = 0.9
    # A proof counts as valid only while its score is strictly above this.
    _VALIDITY_THRESHOLD = 0.5

    def __init__(self, checkpoint_dir: str = "checkpoints"):
        """Create the checkpoint and validation directories if missing.

        Args:
            checkpoint_dir: Root directory for checkpoints.  It may be a
                nested path; missing parent directories are created.
        """
        self.checkpoint_dir = Path(checkpoint_dir)
        # parents=True so that a nested path such as "runs/exp1" works.
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.validation_dir = self.checkpoint_dir / "validation"
        self.validation_dir.mkdir(exist_ok=True)

        # Per-metric history lists; nothing in this class appends to them.
        self.metrics: Dict[str, List[Any]] = {
            "accuracy": [],
            "equation_simplification": [],
            "proof_validation": [],
            "memory_usage": [],
        }

    @staticmethod
    def _default_name() -> str:
        """Return a timestamp-based file name component (second resolution)."""
        return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    @staticmethod
    def _write_json(path: Path, payload: Dict[str, Any]) -> None:
        """Write *payload* to *path* as indented JSON."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2)

    @staticmethod
    def _check_step(step: Any) -> Optional[bool]:
        """Classify one proof step as true, false or unverifiable.

        Returns:
            ``True`` if the step is an equation whose two sides are equal,
            ``False`` if it is an equation shown to be false, and ``None``
            if it is not a single ``lhs = rhs`` equation, cannot be parsed,
            or its truth cannot be decided.
        """
        if not isinstance(step, str) or step.count("=") != 1:
            return None
        lhs, rhs = (side.strip() for side in step.split("="))
        # "<=", ">=" and "!=" are not equations; without this guard "3 != 5"
        # would be parsed as "factorial(3) = 5" and penalised.
        if not lhs or not rhs or lhs[-1] in "<>!":
            return None
        try:
            relation = Eq(simplify(lhs), simplify(rhs))
            # An Eq that cannot be decided stays symbolic and raises
            # TypeError when converted to bool; that lands in the handler
            # below and is reported as "unverifiable".
            return bool(relation)
        except Exception:
            # Best effort: free-form text around an "=" fails to parse in
            # many different ways, none of which should abort the proof.
            return None

    def validate_equation(self, equation: str, expected_result: str) -> Dict[str, Any]:
        """Check whether two expressions are mathematically equal.

        NOTE(review): ``simplify`` parses the strings with sympy's
        ``sympify``, which evaluates them; do not pass untrusted input.

        Args:
            equation: Expression to check, in a form sympy can parse.
            expected_result: Expression it is expected to equal.

        Returns:
            On success a dict with ``is_correct``, ``simplified_lhs``,
            ``simplified_rhs`` and ``validation_score`` (1.0 or 0.0).  If
            either string cannot be processed, a dict with
            ``is_correct=False``, ``error`` and ``validation_score=0.0``.
        """
        try:
            lhs = simplify(equation)
            rhs = simplify(expected_result)

            is_correct = bool(lhs == rhs)
            if not is_correct:
                # Structural equality misses equal expressions that were
                # simplified to different forms, e.g. (x + 1)**2 versus
                # x**2 + 2*x + 1, so also test whether the difference
                # simplifies to zero.
                try:
                    is_correct = bool(simplify(lhs - rhs) == 0)
                except Exception:
                    # Operands that cannot be subtracted (relations,
                    # booleans, ...): keep the structural verdict.
                    is_correct = False

            return {
                "is_correct": is_correct,
                "simplified_lhs": str(lhs),
                "simplified_rhs": str(rhs),
                "validation_score": float(is_correct),
            }
        except Exception as e:
            return {
                "is_correct": False,
                "error": str(e),
                "validation_score": 0.0,
            }

    def validate_proof(self, proof_steps: List[str], expected_theorem: str) -> Dict[str, Any]:
        """Score a proof given as a list of textual steps.

        Each step of the form ``lhs = rhs`` is checked with sympy; every
        step shown to be false multiplies the score by 0.9.  Steps that are
        not equations, cannot be parsed, or cannot be decided do not change
        the score and are counted in ``unverified_steps``.

        Args:
            proof_steps: Ordered proof steps.
            expected_theorem: Text that must occur in the last step (plain
                substring test; an empty string always matches).

        Returns:
            A dict with ``is_valid`` (score strictly above 0.5),
            ``validation_score``, ``matches_theorem``, ``context_size``
            (number of distinct steps) and ``unverified_steps``.  On
            failure, including an empty ``proof_steps``, a dict with
            ``is_valid=False``, ``error`` and ``validation_score=0.0``.
        """
        if not proof_steps:
            return {
                "is_valid": False,
                "error": "proof_steps is empty",
                "validation_score": 0.0,
            }

        try:
            seen_steps = set()
            validation_score = 1.0
            unverified_steps = 0

            for step in proof_steps:
                verdict = self._check_step(step)
                if verdict is None:
                    unverified_steps += 1
                elif not verdict:
                    validation_score *= self._STEP_PENALTY
                seen_steps.add(step)

            final_step = proof_steps[-1]
            matches_theorem = expected_theorem in final_step

            return {
                "is_valid": validation_score > self._VALIDITY_THRESHOLD,
                "validation_score": validation_score,
                "matches_theorem": matches_theorem,
                "context_size": len(seen_steps),
                "unverified_steps": unverified_steps,
            }
        except Exception as e:
            return {
                "is_valid": False,
                "error": str(e),
                "validation_score": 0.0,
            }

    def create_checkpoint(self, data: Dict[str, Any], name: Optional[str] = None) -> str:
        """Write *data* to ``checkpoint_<name>.json`` and return its path.

        The stored payload is a shallow copy of *data* with two added keys:
        ``timestamp`` and ``hash``, the SHA-256 hex digest of ``str()`` of
        the payload (timestamp included, any previous ``hash`` excluded).
        The caller's dictionary is not modified.

        Args:
            data: JSON-serialisable checkpoint content.
            name: File name component; defaults to the current time
                formatted as ``YYYYmmdd_HHMMSS``.

        Returns:
            The path of the written file, as a string.
        """
        if name is None:
            name = self._default_name()
        checkpoint_path = self.validation_dir / f"checkpoint_{name}.json"

        # Work on a copy so repeated calls with the same dict neither leak
        # keys into the caller's data nor hash a stale "hash" entry.
        payload = dict(data)
        payload.pop("hash", None)
        payload["timestamp"] = str(datetime.datetime.now())
        payload["hash"] = hashlib.sha256(str(payload).encode()).hexdigest()

        self._write_json(checkpoint_path, payload)
        return str(checkpoint_path)

    def load_checkpoint(self, name: str) -> Optional[Dict[str, Any]]:
        """Load ``checkpoint_<name>.json``.

        Returns:
            The parsed JSON content, or ``None`` if the file does not exist.
        """
        checkpoint_path = self.validation_dir / f"checkpoint_{name}.json"
        try:
            with open(checkpoint_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            return None

    def validate_dataset(self, dataset: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Validate every example of a dataset and aggregate the results.

        An example with an ``equation`` key is checked against its
        ``expected_result``; one with a ``proof_steps`` key is checked
        against its ``theorem``.  An example may produce zero, one or two
        results, so ``processed_examples`` counts results, not examples.

        Returns:
            A dict with ``total_examples``, ``processed_examples``,
            ``error_count`` (results that carry an ``error`` key),
            ``average_score`` and ``detailed_results``.
        """
        results: List[Dict[str, Any]] = []

        for example in dataset:
            try:
                if "equation" in example:
                    results.append(
                        self.validate_equation(
                            example["equation"],
                            example.get("expected_result", ""),
                        )
                    )
                if "proof_steps" in example:
                    results.append(
                        self.validate_proof(
                            example["proof_steps"],
                            example.get("theorem", ""),
                        )
                    )
            except Exception as e:
                # Malformed example (e.g. not a mapping): record and go on.
                results.append({
                    "error": str(e),
                    "validation_score": 0.0,
                })

        # The validators report failures through an "error" key instead of
        # raising, so errors must be counted from the results themselves.
        error_count = sum(1 for result in results if "error" in result)

        scores = [r["validation_score"] for r in results if "validation_score" in r]
        avg_score = float(np.mean(scores)) if scores else 0.0

        return {
            "total_examples": len(dataset),
            "processed_examples": len(results),
            "error_count": error_count,
            "average_score": avg_score,
            "detailed_results": results,
        }

    def save_validation_report(self, report: Dict[str, Any], name: Optional[str] = None) -> str:
        """Write *report* to ``report_<name>.json`` and return its path.

        The stored payload is a shallow copy of *report* with an added
        ``timestamp`` and a ``summary`` holding ``accuracy`` (the report's
        ``average_score``) and ``error_rate`` (``error_count`` divided by
        ``total_examples``, or 0.0 when there are no examples).  The
        caller's dictionary is not modified.

        Args:
            report: Report dictionary such as the one returned by
                ``validate_dataset``.
            name: File name component; defaults to the current time
                formatted as ``YYYYmmdd_HHMMSS``.

        Returns:
            The path of the written file, as a string.
        """
        if name is None:
            name = self._default_name()
        report_path = self.validation_dir / f"report_{name}.json"

        total_examples = report.get("total_examples", 1)
        error_count = report.get("error_count", 0)
        # An empty dataset reports total_examples == 0; avoid dividing by it.
        error_rate = error_count / total_examples if total_examples else 0.0

        payload = dict(report)
        payload["timestamp"] = str(datetime.datetime.now())
        payload["summary"] = {
            "accuracy": report.get("average_score", 0.0),
            "error_rate": error_rate,
        }

        self._write_json(report_path, payload)
        return str(report_path)