File size: 9,993 Bytes
8a682b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
"""
Enhanced error handling and recovery mechanisms for the AI agent.
"""

import time
import logging
from typing import Dict, Any, Optional, List
from dataclasses import dataclass
from enum import Enum

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

class ErrorCategory(Enum):
    """Enhanced error categories for better error handling.

    ``ErrorHandler.categorize_error`` maps error text to one of these by
    keyword matching, and ``ErrorHandler.get_retry_strategy`` /
    ``ErrorHandler.get_retry_suggestions`` key their behavior on the result.
    Each value is the member name in lowercase.
    """
    # API and Rate Limiting
    RATE_LIMIT = "rate_limit"
    NETWORK = "network"
    AUTH = "auth"
    
    # Data and Validation
    TOOL_VALIDATION = "tool_validation"
    NOT_FOUND = "not_found"
    DATA_FORMAT = "data_format"
    
    # Model and Processing
    MODEL_ERROR = "model_error"
    RESOURCE_LIMIT = "resource_limit"
    PROCESSING_TIMEOUT = "processing_timeout"
    
    # Logic and Reasoning
    LOGIC_ERROR = "logic_error"
    INFERENCE_ERROR = "inference_error"
    
    # Default: fallback returned when no keyword rule in categorize_error matches
    GENERAL = "general"

@dataclass
class RetryStrategy:
    """Configuration for retry behavior.

    Instances are produced by ``ErrorHandler.get_retry_strategy``. Every field
    is required (no defaults), so positional construction follows the
    declaration order below. How the fields are applied is up to the caller;
    the consumer is not in this section of the file, so the per-field notes
    below are best-effort readings of the names.
    """
    max_retries: int  # cap on retry attempts (whether it includes the first try is caller-defined — confirm)
    backoff_factor: float  # presumably the multiplier applied to the delay between attempts — confirm
    timeout_increase: float  # presumably a per-retry timeout multiplier; 1.0 looks like "unchanged" — confirm
    use_exponential_backoff: bool  # per the name: exponential vs. flat backoff — confirm with caller
    should_switch_tool: bool  # advisory flag: try an alternative tool
    should_validate_input: bool  # advisory flag: re-check tool inputs before retrying
    should_reduce_scope: bool  # advisory flag: shrink the request
    should_verify_reasoning: bool  # advisory flag: re-check reasoning steps

@dataclass
class ToolExecutionResult:
    """Result of a tool execution attempt.

    A plain record with every field required, so positional construction
    follows the declaration order below. Nothing in this section of the
    file constructs or reads it; the notes below are hedged accordingly.
    """
    success: bool  # whether the attempt succeeded
    output: Any  # tool output; presumably meaningless when success is False — confirm with caller
    error: Optional[Exception]  # exception raised by the attempt, presumably None on success
    error_category: Optional[ErrorCategory]  # likely from ErrorHandler.categorize_error — confirm
    retry_suggestions: List[str]  # human-readable hints, e.g. from ErrorHandler.get_retry_suggestions — confirm

class CircuitBreakerOpenError(Exception):
    """Raised by ``CircuitBreaker.call`` while the breaker is open.

    It subclasses ``Exception`` so existing ``except Exception`` handlers still
    catch it. The message is still "Circuit breaker is open".
    """


class CircuitBreaker:
    """Circuit breaker pattern for fault tolerance.

    States:
      * ``"closed"``: calls pass through; consecutive failures are counted.
      * ``"open"``: calls are rejected with ``CircuitBreakerOpenError`` until
        more than ``recovery_timeout`` seconds have passed since the last
        failure.
      * ``"half-open"``: the next call is let through as a trial. Success
        closes the breaker; failure re-opens it immediately.

    Args:
        failure_threshold: Number of consecutive failures that opens the breaker.
        recovery_timeout: Seconds to stay open before allowing a trial call.
    """

    def __init__(self, failure_threshold: int = 5, recovery_timeout: int = 60):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.failure_count = 0
        # Wall-clock time (time.time()) of the most recent failure, or None.
        self.last_failure_time: Optional[float] = None
        self.state = "closed"  # one of "closed", "open", "half-open"

    def call(self, func, *args, **kwargs):
        """Call ``func(*args, **kwargs)`` under circuit-breaker protection.

        Returns:
            Whatever ``func`` returns.

        Raises:
            CircuitBreakerOpenError: The breaker is open and the recovery
                timeout has not elapsed; ``func`` is not called.
            Exception: Any exception raised by ``func`` is recorded as a
                failure and re-raised unchanged.
        """
        if self.state == "open":
            still_cooling = (
                self.last_failure_time is not None
                and time.time() - self.last_failure_time <= self.recovery_timeout
            )
            if still_cooling:
                raise CircuitBreakerOpenError("Circuit breaker is open")
            # Cool-down elapsed (or no failure time recorded): allow a trial call.
            self.state = "half-open"

        try:
            result = func(*args, **kwargs)
        except Exception:
            self._record_failure()
            raise  # bare raise keeps the original traceback untouched

        # Any success clears the streak, so only *consecutive* failures count
        # toward the threshold; it also closes a half-open breaker.
        self.failure_count = 0
        self.state = "closed"
        return result

    def _record_failure(self):
        """Count one failure and open the breaker when warranted."""
        self.failure_count += 1
        self.last_failure_time = time.time()
        # A failed half-open trial reopens at once, regardless of the count.
        if self.state == "half-open" or self.failure_count >= self.failure_threshold:
            self.state = "open"
            logger.error("Circuit breaker opened after %d failures", self.failure_count)

class ErrorHandler:
    """Classify errors, choose retry strategies, and track recovery statistics.

    Per-instance state:
      * ``error_counts``: how many times each category was passed to
        ``track_error``.
      * ``recovery_history``: per-category ``{"success": n, "failure": m}``
        counters maintained by ``record_recovery``.
      * ``circuit_breakers``: one ``CircuitBreaker`` per tool name, created
        lazily by ``get_circuit_breaker``.
    """

    def __init__(self):
        self.error_counts: Dict[ErrorCategory, int] = {}
        self.recovery_history: Dict[ErrorCategory, Dict[str, int]] = {}
        # One breaker per tool name, created on first request.
        self.circuit_breakers: Dict[str, CircuitBreaker] = {}

    def categorize_error(self, error_str: str) -> ErrorCategory:
        """Map an error message to an ``ErrorCategory`` by keyword matching.

        Args:
            error_str: The error text. It is passed through ``str()`` first, so
                an exception instance is accepted as well.

        Returns:
            The first matching category, tested in the order below, or
            ``ErrorCategory.GENERAL`` when nothing matches. Matching is a
            case-insensitive substring test, so order matters. For example, a
            message containing both "timeout" and "invalid" is NETWORK.
        """
        text = str(error_str).lower()

        def has(*needles: str) -> bool:
            return any(needle in text for needle in needles)

        # API and Rate Limiting
        if has("429", "rate limit"):
            return ErrorCategory.RATE_LIMIT
        if has("timeout", "connection"):
            # Every "timeout" message lands here, so PROCESSING_TIMEOUT below
            # is only reachable via "deadline".
            return ErrorCategory.NETWORK
        if has("authentication", "401"):
            return ErrorCategory.AUTH

        # Data and Validation
        if has("validation", "invalid"):
            return ErrorCategory.TOOL_VALIDATION
        if has("not found", "404"):
            return ErrorCategory.NOT_FOUND
        if has("format", "parse"):
            return ErrorCategory.DATA_FORMAT

        # Model and Processing
        if "model" in text and "decommissioned" in text:
            return ErrorCategory.MODEL_ERROR
        if has("memory", "resource"):
            return ErrorCategory.RESOURCE_LIMIT
        if has("deadline"):
            return ErrorCategory.PROCESSING_TIMEOUT

        # Logic and Reasoning
        if has("logic", "reasoning"):
            return ErrorCategory.LOGIC_ERROR
        if has("inference", "prediction"):
            return ErrorCategory.INFERENCE_ERROR

        return ErrorCategory.GENERAL

    def get_retry_strategy(self, error_category: ErrorCategory, state: Dict[str, Any]) -> RetryStrategy:
        """Return a fresh ``RetryStrategy`` tuned for ``error_category``.

        Categories without a dedicated branch (AUTH, NOT_FOUND, DATA_FORMAT,
        MODEL_ERROR, PROCESSING_TIMEOUT, INFERENCE_ERROR, GENERAL) get the
        default strategy. The default allows 3 retries, grows backoff and
        timeout by 1.5, uses exponential backoff, and sets no corrective flags.

        Args:
            error_category: Category, typically from ``categorize_error``.
            state: Agent state. It is currently unused and kept for API
                compatibility.

        Returns:
            A new ``RetryStrategy`` on every call, so callers may mutate it.
        """
        if error_category == ErrorCategory.RATE_LIMIT:
            return RetryStrategy(
                max_retries=5,
                backoff_factor=2.0,
                timeout_increase=1.0,
                use_exponential_backoff=True,
                should_switch_tool=True,
                should_validate_input=False,
                should_reduce_scope=False,
                should_verify_reasoning=False
            )
        if error_category == ErrorCategory.NETWORK:
            return RetryStrategy(
                max_retries=3,
                backoff_factor=1.2,
                timeout_increase=1.0,
                use_exponential_backoff=True,
                should_switch_tool=False,
                should_validate_input=False,
                should_reduce_scope=False,
                should_verify_reasoning=False
            )
        if error_category == ErrorCategory.TOOL_VALIDATION:
            return RetryStrategy(
                max_retries=2,
                backoff_factor=1.0,
                timeout_increase=1.0,
                use_exponential_backoff=False,
                should_switch_tool=True,
                should_validate_input=True,
                should_reduce_scope=False,
                should_verify_reasoning=False
            )
        if error_category == ErrorCategory.RESOURCE_LIMIT:
            return RetryStrategy(
                max_retries=2,
                backoff_factor=1.0,
                timeout_increase=1.0,
                use_exponential_backoff=False,
                should_switch_tool=True,
                should_validate_input=False,
                should_reduce_scope=True,
                should_verify_reasoning=False
            )
        if error_category == ErrorCategory.LOGIC_ERROR:
            return RetryStrategy(
                max_retries=2,
                backoff_factor=1.0,
                timeout_increase=1.0,
                use_exponential_backoff=False,
                should_switch_tool=True,
                should_validate_input=False,
                should_reduce_scope=False,
                should_verify_reasoning=True
            )

        # Default strategy, built only when no specialised branch applied.
        return RetryStrategy(
            max_retries=3,
            backoff_factor=1.5,
            timeout_increase=1.5,
            use_exponential_backoff=True,
            should_switch_tool=False,
            should_validate_input=False,
            should_reduce_scope=False,
            should_verify_reasoning=False
        )

    def get_retry_suggestions(self, error_category: ErrorCategory) -> List[str]:
        """Return human-readable retry hints for ``error_category``.

        The table is rebuilt on each call, so the returned list is always a new
        object. Categories not in the table get ``["Retry with modified input"]``.
        """
        suggestions = {
            ErrorCategory.RATE_LIMIT: [
                "Wait before retrying",
                "Use exponential backoff",
                "Consider alternative tool"
            ],
            ErrorCategory.NETWORK: [
                "Check network connection",
                "Retry with longer timeout",
                "Try smaller request"
            ],
            ErrorCategory.TOOL_VALIDATION: [
                "Fix parameter names",
                "Check data types",
                "Review tool documentation"
            ],
            ErrorCategory.RESOURCE_LIMIT: [
                "Reduce request scope",
                "Use simpler analysis",
                "Try alternative approach"
            ],
            ErrorCategory.LOGIC_ERROR: [
                "Verify reasoning steps",
                "Check assumptions",
                "Try different approach"
            ],
            ErrorCategory.GENERAL: [
                "Check tool availability",
                "Review input format",
                "Consider alternative approach"
            ]
        }
        return suggestions.get(error_category, ["Retry with modified input"])

    def track_error(self, error_category: ErrorCategory):
        """Increment the occurrence counter for ``error_category``."""
        self.error_counts[error_category] = self.error_counts.get(error_category, 0) + 1

    def get_error_stats(self) -> Dict[ErrorCategory, int]:
        """Return a snapshot (copy) of the per-category error counts."""
        return dict(self.error_counts)

    def record_recovery(self, error_category: ErrorCategory, success: bool):
        """Record one recovery attempt for ``error_category`` as a success or failure."""
        counts = self.recovery_history.setdefault(error_category, {"success": 0, "failure": 0})
        counts["success" if success else "failure"] += 1

    def get_recovery_stats(self) -> Dict[ErrorCategory, Dict[str, int]]:
        """Return a snapshot of recovery statistics.

        The inner per-category dicts are copied as well. A shallow copy would
        let callers mutate the live counters through the returned snapshot.
        """
        return {category: dict(counts) for category, counts in self.recovery_history.items()}

    def get_circuit_breaker(self, tool_name: str) -> CircuitBreaker:
        """Return the breaker for ``tool_name``, creating a default one on first use."""
        breaker = self.circuit_breakers.get(tool_name)
        if breaker is None:
            breaker = CircuitBreaker()
            self.circuit_breakers[tool_name] = breaker
        return breaker