File size: 6,182 Bytes
b8ab4a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""
Validation module for the Math Expert model
"""
import os
import json
from pathlib import Path
import hashlib
import datetime
from typing import Dict, Any, List, Optional
import numpy as np
from sympy import simplify, Eq

class MathValidator:
    """Validate equations, proof steps and whole datasets; persist results.

    Checkpoints and reports are written as JSON files under
    ``<checkpoint_dir>/validation``.
    """

    def __init__(self, checkpoint_dir: str = "checkpoints"):
        """Create the checkpoint and validation directories if they are missing.

        Args:
            checkpoint_dir: Root directory for checkpoints. Missing parent
                directories are created as well.
        """
        self.checkpoint_dir = Path(checkpoint_dir)
        # parents=True so nested roots such as "runs/exp1/checkpoints" work.
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.validation_dir = self.checkpoint_dir / "validation"
        self.validation_dir.mkdir(exist_ok=True)

        # Metric buckets. No method of this class appends to them; presumably
        # callers populate them — TODO confirm.
        self.metrics = {
            "accuracy": [],
            "equation_simplification": [],
            "proof_validation": [],
            "memory_usage": [],
        }

    @staticmethod
    def _symbolically_equal(lhs: Any, rhs: Any) -> bool:
        """Return True if two (already simplified) SymPy objects are equal.

        Structural equality is tried first. Otherwise the difference is
        simplified and compared with zero, because ``simplify`` does not
        always bring equivalent expressions to the same form (e.g.
        ``(x+1)**2`` vs ``x**2 + 2*x + 1``).
        """
        if lhs == rhs:
            return True
        try:
            return bool(simplify(lhs - rhs) == 0)
        except Exception:
            # Objects without subtraction (relations, booleans, ...) can only
            # be compared structurally, which already failed above.
            return False

    def validate_equation(self, equation: str, expected_result: str) -> Dict[str, Any]:
        """Check whether ``equation`` is mathematically equal to ``expected_result``.

        Both strings are parsed and simplified with SymPy. They count as equal
        when the simplified forms match or their difference simplifies to 0.

        NOTE(security): ``simplify`` on a string parses it with ``sympify``,
        which uses ``eval``; do not pass untrusted text here.

        Returns:
            On success: ``is_correct``, ``simplified_lhs``, ``simplified_rhs``
            and ``validation_score`` (1.0 or 0.0). If parsing or simplifying
            raises: ``is_correct`` False, ``error`` (the message) and
            ``validation_score`` 0.0.
        """
        try:
            lhs = simplify(equation)
            rhs = simplify(expected_result)
            is_correct = self._symbolically_equal(lhs, rhs)
            return {
                "is_correct": is_correct,
                "simplified_lhs": str(lhs),
                "simplified_rhs": str(rhs),
                "validation_score": float(is_correct),
            }
        except Exception as e:
            return {
                "is_correct": False,
                "error": str(e),
                "validation_score": 0.0,
            }

    @staticmethod
    def _has_false_equality(step: str) -> bool:
        """Return True if SymPy proves some ``a = b`` link in ``step`` false.

        The step is split on ``=`` and each adjacent pair of pieces is
        compared, so chained steps (``a = b = c``) are checked link by link.
        Steps without ``=``, pieces SymPy cannot parse (prose, ``<=``,
        ``==``), and equalities SymPy cannot decide (e.g. ``x**2 = x``) are
        not treated as false.
        """
        pieces = step.split("=")
        if len(pieces) < 2:
            return False
        try:
            exprs = [simplify(piece) for piece in pieces]
        except Exception:
            # Best effort: text SymPy cannot parse is treated as prose.
            return False
        for left, right in zip(exprs, exprs[1:]):
            try:
                holds = bool(Eq(left, right))
            except Exception:
                # bool() of an unevaluated relation raises TypeError: SymPy
                # cannot decide this link, so it is not penalised.
                continue
            if not holds:
                return True
        return False

    def validate_proof(self, proof_steps: List[str], expected_theorem: str) -> Dict[str, Any]:
        """Score a proof by checking its equation-like steps.

        The score starts at 1.0. Each step with an ``=`` link that SymPy
        proves false multiplies it by 0.9; other steps cost nothing.

        Args:
            proof_steps: Proof steps in order. The last one is compared with
                ``expected_theorem``.
            expected_theorem: Text expected to appear in the last step. An
                empty value never matches.

        Returns:
            ``is_valid`` (score > 0.5), ``validation_score``,
            ``matches_theorem`` and ``context_size`` (number of distinct
            steps). For an empty proof or any unexpected error: ``is_valid``
            False, ``error`` and ``validation_score`` 0.0.
        """
        try:
            if not proof_steps:
                return {
                    "is_valid": False,
                    "error": "proof_steps is empty",
                    "validation_score": 0.0,
                }

            current_context = set()
            validation_score = 1.0
            for step in proof_steps:
                # Non-string steps cannot be equations; they are only counted.
                if isinstance(step, str) and self._has_false_equality(step):
                    validation_score *= 0.9  # penalise provably incorrect steps
                current_context.add(step)

            final_step = proof_steps[-1]
            # `"" in s` is always True, so an absent theorem must not "match".
            matches_theorem = bool(expected_theorem) and expected_theorem in final_step

            return {
                "is_valid": validation_score > 0.5,
                "validation_score": validation_score,
                "matches_theorem": matches_theorem,
                "context_size": len(current_context),
            }
        except Exception as e:
            return {
                "is_valid": False,
                "error": str(e),
                "validation_score": 0.0,
            }

    def create_checkpoint(self, data: Dict[str, Any], name: Optional[str] = None) -> str:
        """Write ``data`` plus ``timestamp`` and ``hash`` to a JSON checkpoint.

        The caller's dict is not modified. ``hash`` is the SHA-256 of
        ``str()`` of the payload (data + timestamp, excluding any previous
        ``hash``), so re-checkpointing loaded data does not chain hashes.

        Args:
            data: JSON-serialisable mapping to store.
            name: File stem suffix; defaults to the current local time as
                ``YYYYmmdd_HHMMSS``.

        Returns:
            Path of the written ``checkpoint_<name>.json`` file, as a string.
        """
        now = datetime.datetime.now()  # one clock read for name and timestamp
        if name is None:
            name = now.strftime("%Y%m%d_%H%M%S")

        checkpoint_path = self.validation_dir / f"checkpoint_{name}.json"

        payload = dict(data)
        payload.pop("hash", None)  # never hash over a stale hash
        payload["timestamp"] = str(now)
        payload["hash"] = hashlib.sha256(str(payload).encode()).hexdigest()

        # Serialise first so a non-JSON value cannot leave a truncated file.
        text = json.dumps(payload, indent=2)
        checkpoint_path.write_text(text, encoding="utf-8")

        return str(checkpoint_path)

    def load_checkpoint(self, name: str) -> Optional[Dict[str, Any]]:
        """Load ``checkpoint_<name>.json`` from the validation directory.

        Returns None when the file does not exist. The stored ``hash`` is
        not verified.
        """
        checkpoint_path = self.validation_dir / f"checkpoint_{name}.json"
        if not checkpoint_path.exists():
            return None

        with open(checkpoint_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def validate_dataset(self, dataset: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Run equation and/or proof validation on every example.

        An example yields one result per ``equation`` / ``proof_steps`` key
        it contains. If handling an example raises (e.g. it is not a
        mapping), a single error entry is recorded and ``error_count`` is
        incremented. Errors reported inside a validation result are not
        counted in ``error_count``.

        Returns:
            ``total_examples``, ``processed_examples`` (examples that
            produced at least one entry in ``detailed_results``),
            ``error_count``, ``average_score`` (mean ``validation_score``
            over all entries, 0.0 if none) and ``detailed_results``.
        """
        results: List[Dict[str, Any]] = []
        error_count = 0
        processed_examples = 0

        for example in dataset:
            results_before = len(results)
            try:
                if "equation" in example:
                    results.append(
                        self.validate_equation(
                            example["equation"],
                            example.get("expected_result", ""),
                        )
                    )
                if "proof_steps" in example:
                    results.append(
                        self.validate_proof(
                            example["proof_steps"],
                            example.get("theorem", ""),
                        )
                    )
            except Exception as e:
                error_count += 1
                results.append({
                    "error": str(e),
                    "validation_score": 0.0,
                })
            # Count examples, not results: one example may yield two results.
            if len(results) > results_before:
                processed_examples += 1

        scores = [r["validation_score"] for r in results if "validation_score" in r]
        avg_score = float(np.mean(scores)) if scores else 0.0

        return {
            "total_examples": len(dataset),
            "processed_examples": processed_examples,
            "error_count": error_count,
            "average_score": avg_score,
            "detailed_results": results,
        }

    def save_validation_report(self, report: Dict[str, Any], name: Optional[str] = None) -> str:
        """Write ``report`` plus ``timestamp`` and ``summary`` to a JSON file.

        The caller's dict is not modified. ``summary`` holds ``accuracy``
        (the report's ``average_score``) and ``error_rate``
        (``error_count / total_examples``, 0.0 when there are no examples).

        Args:
            report: JSON-serialisable mapping, e.g. from ``validate_dataset``.
            name: File stem suffix; defaults to the current local time as
                ``YYYYmmdd_HHMMSS``.

        Returns:
            Path of the written ``report_<name>.json`` file, as a string.
        """
        now = datetime.datetime.now()
        if name is None:
            name = now.strftime("%Y%m%d_%H%M%S")

        report_path = self.validation_dir / f"report_{name}.json"

        total_examples = report.get("total_examples", 1)
        error_count = report.get("error_count", 0)

        payload = dict(report)
        payload["timestamp"] = str(now)
        payload["summary"] = {
            "accuracy": report.get("average_score", 0.0),
            # validate_dataset([]) reports 0 examples; avoid ZeroDivisionError.
            "error_rate": error_count / total_examples if total_examples else 0.0,
        }

        text = json.dumps(payload, indent=2)
        report_path.write_text(text, encoding="utf-8")

        return str(report_path)