File size: 5,131 Bytes
95594cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""

Automatic Error Correction

Applies corrections to fixable errors in solution steps

Tracks correction success rates

"""

import re
from typing import List, Dict, Any
from sympy import sympify, simplify, N


def correct_solution(steps: List[str], errors: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Automatically correct fixable errors in solution steps.

    Args:
        steps: Original solution steps. The input sequence is not modified;
            corrections are applied to a list copy.
        errors: Error dictionaries. Keys read here: ``"step_number"``
            (1-based), ``"type"`` and ``"fixable"``; the per-type correction
            helpers may read further keys (e.g. ``"found"``, ``"correct"``,
            ``"description"``).

    Returns:
        Dictionary with keys ``corrected_steps``, ``correction_log``,
        ``success_rate`` (``fixed_count / total_fixable``, or ``0.0`` when no
        error is marked fixable), ``manual_review_needed``, ``fixed_count``
        and ``total_fixable``.

        An error is routed to ``manual_review_needed`` when it references a
        step that does not exist, is not marked fixable, has a type with no
        automatic correction, or when its correction helper reports failure.
    """
    corrected_steps = list(steps)
    correction_log = []
    manual_review_needed = []
    fixed_count = 0

    for error in errors:
        step_index = error.get("step_number", 0) - 1  # Convert to 0-based index
        if not 0 <= step_index < len(corrected_steps):
            # A reference to a non-existent step cannot be corrected here;
            # report it rather than dropping it (a fixable one still counts
            # towards total_fixable below).
            manual_review_needed.append(error)
            continue

        if not error.get("fixable", False):
            manual_review_needed.append(error)
            continue

        error_type = error.get("type", "")
        if error_type == "calculation_error":
            log_type, reason = "arithmetic", "Arithmetic calculation corrected"
            success = _correct_arithmetic_error(corrected_steps, step_index, error)
        elif error_type == "operation_mismatch":
            log_type, reason = "operation_mismatch", "Operation mismatch corrected"
            success = _correct_operation_mismatch(corrected_steps, step_index, error)
        else:
            # Other error types have no automatic correction.
            manual_review_needed.append(error)
            continue

        if not success:
            manual_review_needed.append(error)
            continue

        fixed_count += 1
        correction_log.append({
            "step": step_index + 1,
            "type": log_type,
            "original": steps[step_index],
            "corrected": corrected_steps[step_index],
            "reason": reason,
        })

    total_fixable = sum(1 for e in errors if e.get("fixable", False))
    success_rate = fixed_count / total_fixable if total_fixable > 0 else 0.0

    return {
        "corrected_steps": corrected_steps,
        "correction_log": correction_log,
        "success_rate": success_rate,
        "manual_review_needed": manual_review_needed,
        "fixed_count": fixed_count,
        "total_fixable": total_fixable,
    }


def _correct_arithmetic_error(steps: List[str], step_index: int, error: Dict[str, Any]) -> bool:
    """Correct arithmetic calculation error."""
    try:
        found = error.get("found", "")
        correct = error.get("correct", "")
        
        # Extract the incorrect result from found
        found_nums = re.findall(r'\d+\.?\d*', found)
        correct_nums = re.findall(r'\d+\.?\d*', correct)
        
        if not found_nums or not correct_nums:
            return False
        
        incorrect_result = found_nums[-1]
        correct_result = correct_nums[-1]
        
        # Replace incorrect result with correct result in the step
        step = steps[step_index]
        # Replace the last occurrence of the incorrect result
        corrected_step = step.replace(incorrect_result, correct_result, 1)
        
        # If that didn't work, try more sophisticated replacement
        if corrected_step == step:
            # Try replacing the full expression
            corrected_step = step.replace(found, correct)
        
        steps[step_index] = corrected_step
        return True
    except Exception as e:
        return False


def _correct_operation_mismatch(steps: List[str], step_index: int, error: Dict[str, Any]) -> bool:
    """Correct operation mismatch error."""
    try:
        description = error.get("description", "")
        step = steps[step_index]
        
        # Extract operation from description
        # This is simplified - in production would use more sophisticated NLP
        if "subtract" in description.lower() and "+" in step:
            # Replace + with -
            corrected_step = step.replace("+", "-", 1)
            steps[step_index] = corrected_step
            return True
        elif "add" in description.lower() and "-" in step:
            # Replace - with +
            corrected_step = step.replace("-", "+", 1)
            steps[step_index] = corrected_step
            return True
        
        return False
    except Exception as e:
        return False