# AIDA / app /core /error_handling.py
# destinyebuka's picture
# dora
# 4c9881b
# ============================================================
# app/core/error_handling.py - Error Handling & Resilience
# ============================================================
import logging
import asyncio
from typing import Callable, Any, Optional, Tuple, TypeVar, Awaitable
from functools import wraps
from enum import Enum
import time
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
retry_if_result,
RetryError,
)
from app.core.observability import trace_operation
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
# Generic result type used in the signatures of the async helpers in this file.
T = TypeVar('T')
# ============================================================
# Error Types
# ============================================================
class LojizError(Exception):
    """Root of the Lojiz exception hierarchy.

    Attributes:
        message: Human-readable description; also used as the exception text.
        error_code: Short machine-readable identifier for the failure.
        status_code: Numeric status associated with the failure.
        recoverable: Flag stored for callers to inspect; this class does not
            act on it.
        context: Extra key/value details; an empty dict when none is given.
    """

    def __init__(
        self,
        message: str,
        error_code: str = "INTERNAL_ERROR",
        status_code: int = 500,
        recoverable: bool = False,
        context: Optional[dict] = None,
    ):
        super().__init__(message)
        self.message = message
        self.error_code = error_code
        self.status_code = status_code
        self.recoverable = recoverable
        # A falsy context (None or an empty dict) is replaced by a fresh dict.
        self.context = context if context else {}
class LLMError(LojizError):
    """Failure raised for LLM-related problems.

    Always carries error code ``LLM_ERROR`` and status code 503, and is
    marked recoverable unless the caller passes ``recoverable=False``.
    """

    def __init__(self, message: str, recoverable: bool = True, **kwargs):
        # error_code and status_code are pinned here; remaining keyword
        # arguments (e.g. context) are forwarded to the base class unchanged.
        super().__init__(
            message,
            recoverable=recoverable,
            status_code=503,
            error_code="LLM_ERROR",
            **kwargs,
        )
class VectorDBError(LojizError):
    """Failure raised for vector database problems.

    Always carries error code ``VECTOR_DB_ERROR`` and status code 503, and is
    marked recoverable unless the caller passes ``recoverable=False``.
    """

    def __init__(self, message: str, recoverable: bool = True, **kwargs):
        # error_code and status_code are pinned here; remaining keyword
        # arguments (e.g. context) are forwarded to the base class unchanged.
        super().__init__(
            message,
            recoverable=recoverable,
            status_code=503,
            error_code="VECTOR_DB_ERROR",
            **kwargs,
        )
class CacheError(LojizError):
    """Failure raised for cache/Redis problems.

    Always carries error code ``CACHE_ERROR`` and status code 503, and is
    marked recoverable unless the caller passes ``recoverable=False``.
    """

    def __init__(self, message: str, recoverable: bool = True, **kwargs):
        # error_code and status_code are pinned here; remaining keyword
        # arguments (e.g. context) are forwarded to the base class unchanged.
        super().__init__(
            message,
            recoverable=recoverable,
            status_code=503,
            error_code="CACHE_ERROR",
            **kwargs,
        )
class DatabaseError(LojizError):
    """Failure raised for database problems.

    Always carries error code ``DATABASE_ERROR`` and status code 503, and is
    marked recoverable unless the caller passes ``recoverable=False``.
    """

    def __init__(self, message: str, recoverable: bool = True, **kwargs):
        # error_code and status_code are pinned here; remaining keyword
        # arguments (e.g. context) are forwarded to the base class unchanged.
        super().__init__(
            message,
            recoverable=recoverable,
            status_code=503,
            error_code="DATABASE_ERROR",
            **kwargs,
        )
class ValidationError(LojizError):
    """Failure raised when validation fails.

    Always carries error code ``VALIDATION_ERROR`` and status code 400, and
    is never marked recoverable.
    """

    def __init__(self, message: str, **kwargs):
        # error_code, status_code and recoverable are all pinned here; the
        # remaining keyword arguments (e.g. context) go to the base class.
        super().__init__(
            message,
            recoverable=False,
            status_code=400,
            error_code="VALIDATION_ERROR",
            **kwargs,
        )
# ============================================================
# Retry Strategies
# ============================================================
class RetryStrategy(Enum):
    """Named retry presets.

    Each member's value is a dict with three keys: ``max_attempts`` (total
    tries), ``initial_wait`` and ``max_wait``. ``create_retry_decorator``
    reads these keys to configure tenacity's stop and wait policies.
    """

    # Most attempts, shortest initial wait, highest wait ceiling.
    AGGRESSIVE = dict(max_attempts=5, initial_wait=1, max_wait=30)
    # Middle ground; used as the default by the helpers in this module.
    MODERATE = dict(max_attempts=3, initial_wait=2, max_wait=10)
    # Fewest attempts, longest initial wait.
    CONSERVATIVE = dict(max_attempts=2, initial_wait=5, max_wait=15)
def create_retry_decorator(strategy: RetryStrategy = RetryStrategy.MODERATE):
    """Build a tenacity ``retry`` decorator from a ``RetryStrategy`` preset.

    Args:
        strategy: Preset supplying ``max_attempts``, ``initial_wait`` and
            ``max_wait``.

    Returns:
        The decorator produced by ``tenacity.retry``. It retries only on
        timeout/connection/OS-level errors and is configured with
        ``reraise=True``.
    """
    settings = strategy.value
    first_wait = settings["initial_wait"]

    # The initial wait serves both as the exponential multiplier and as the
    # lower bound of the wait time.
    backoff = wait_exponential(
        multiplier=first_wait,
        min=first_wait,
        max=settings["max_wait"],
    )
    attempt_limit = stop_after_attempt(settings["max_attempts"])

    # ConnectionError and the builtin TimeoutError are OSError subclasses, so
    # OSError already covers them; they are listed explicitly as in the
    # original configuration.
    transient_errors = retry_if_exception_type((
        asyncio.TimeoutError,
        ConnectionError,
        TimeoutError,
        OSError,
    ))

    return retry(
        stop=attempt_limit,
        wait=backoff,
        retry=transient_errors,
        reraise=True,
    )
# ============================================================
# Fallback Chain
# ============================================================
async def call_with_fallback(
    primary: Callable[..., Awaitable[T]],
    fallback: Optional[Callable[..., Awaitable[T]]] = None,
    error_handler: Optional[Callable[[Exception], T]] = None,
    timeout: int = 30,
) -> T:
    """
    Await ``primary`` and degrade step by step when it fails.

    Both ``primary`` and ``fallback`` are called with no arguments and each
    is bounded by ``timeout`` via ``asyncio.wait_for``; a timeout is handled
    like any other exception.

    Args:
        primary: Zero-argument async callable tried first.
        fallback: Optional zero-argument async callable tried if primary fails.
        error_handler: Optional synchronous callable that receives the last
            exception and returns a substitute result.
        timeout: Timeout applied separately to the primary and fallback calls.

    Returns:
        The result of primary, else of fallback, else of error_handler.

    Raises:
        Exception: The last failure (fallback's if a fallback was tried,
            otherwise primary's) when no error_handler is supplied.
    """
    try:
        return await asyncio.wait_for(primary(), timeout=timeout)
    except Exception as primary_error:
        logger.warning(f"⚠️ Primary call failed: {primary_error}")

        # Nothing to fall back to: hand the primary failure to the handler,
        # or propagate it when there is no handler either.
        if not fallback:
            if not error_handler:
                raise primary_error
            return error_handler(primary_error)

        try:
            logger.info("πŸ”„ Trying fallback...")
            return await asyncio.wait_for(fallback(), timeout=timeout)
        except Exception as fallback_error:
            logger.warning(f"⚠️ Fallback also failed: {fallback_error}")
            if not error_handler:
                raise fallback_error
            logger.info("πŸ”„ Using error handler...")
            return error_handler(fallback_error)
# ============================================================
# Decorators
# ============================================================
def async_retry(
    strategy: RetryStrategy = RetryStrategy.MODERATE,
    operation_name: Optional[str] = None,
):
    """Decorator for async functions adding retry, tracing and failure logging.

    Args:
        strategy: Retry preset passed to ``create_retry_decorator``.
        operation_name: Name used for the trace and in the error log;
            defaults to ``"<module>.<function>"`` of the decorated function.

    Returns:
        A decorator producing an async wrapper. The wrapper runs the function
        through the retry policy inside ``trace_operation`` and re-raises
        whatever exception finally escapes, after logging it.
    """
    def decorator(func: Callable) -> Callable:
        op_name = operation_name or f"{func.__module__}.{func.__name__}"
        # Build the retrying callable once at decoration time instead of on
        # every invocation; the strategy and function never change.
        retry_func = create_retry_decorator(strategy)(func)

        @wraps(func)
        async def wrapper(*args, **kwargs):
            with trace_operation(op_name):
                try:
                    return await retry_func(*args, **kwargs)
                except Exception as e:
                    # create_retry_decorator uses reraise=True, so tenacity
                    # re-raises the underlying exception rather than
                    # RetryError. Catching only RetryError therefore never
                    # logged anything; catch Exception (which also covers
                    # RetryError) so the terminal failure is recorded.
                    logger.error(f"❌ {op_name} failed after retries: {e}")
                    raise
        return wrapper
    return decorator
def handle_errors(default_return: Any = None):
    """Decorator that logs failures of an async function and normalizes them.

    Behavior of the wrapped coroutine:
      * ``LojizError`` is logged (with its context) and re-raised unchanged.
      * Any other ``Exception`` is logged with its traceback; then
        ``default_return`` is returned if it is not ``None``, otherwise a
        recoverable ``LojizError`` is raised with the original exception
        attached as its cause.

    Args:
        default_return: Value returned in place of an unexpected exception.
            ``None`` means "no default": the error is raised instead.

    Returns:
        A decorator for async functions.
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                return await func(*args, **kwargs)
            except LojizError as e:
                logger.error(
                    f"❌ {func.__name__} error [{e.error_code}]: {e.message}",
                    extra={"context": e.context}
                )
                raise
            except Exception as e:
                logger.error(
                    f"❌ Unexpected error in {func.__name__}: {str(e)}",
                    exc_info=True
                )
                if default_return is not None:
                    return default_return
                # Chain explicitly so the original failure is preserved as
                # __cause__ on the wrapping error.
                raise LojizError(
                    f"Unexpected error in {func.__name__}",
                    recoverable=True,
                ) from e
        return wrapper
    return decorator
# ============================================================
# Circuit Breaker (for external services)
# ============================================================
class CircuitBreaker:
    """Simple circuit breaker for external service calls.

    The breaker counts consecutive failed calls. Once ``failure_threshold``
    is reached it opens and rejects calls with a ``LojizError`` until
    ``recovery_timeout`` seconds have passed since the last failure; it then
    closes again with the failure count reset to zero.

    Attributes:
        name: Label used in log messages, errors and status reports.
        failure_threshold: Consecutive failures that open the circuit.
        recovery_timeout: Seconds after the last failure before calls are
            allowed again.
        call_timeout: Per-call timeout handed to ``asyncio.wait_for``.
        failures: Current consecutive-failure count.
        last_failure_time: ``time.time()`` of the latest failure, or None.
        is_open: True while calls are being rejected.
    """

    def __init__(
        self,
        name: str,
        failure_threshold: int = 5,
        recovery_timeout: int = 60,
        call_timeout: Optional[float] = 30,
    ):
        """Create a closed breaker.

        Args:
            name: Label for this breaker.
            failure_threshold: Consecutive failures required to open it.
            recovery_timeout: Seconds to wait after the last failure before
                letting calls through again.
            call_timeout: Timeout for each protected call; previously fixed
                at 30. ``None`` is passed through to ``asyncio.wait_for``,
                which then waits without a time limit.
        """
        self.name = name
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.call_timeout = call_timeout
        self.failures = 0
        self.last_failure_time: Optional[float] = None
        self.is_open = False

    async def call(
        self,
        func: Callable[..., Awaitable[T]],
        *args,
        **kwargs
    ) -> T:
        """Call ``func(*args, **kwargs)`` through the circuit breaker.

        Returns:
            The awaited result of ``func``.

        Raises:
            LojizError: With code ``CIRCUIT_BREAKER_OPEN`` when the circuit
                is open and the recovery timeout has not elapsed.
            Exception: Any exception raised by ``func`` (including the
                timeout error from ``asyncio.wait_for``) is counted as a
                failure and re-raised.
        """
        # Check if circuit should close (recovery timeout passed)
        if self.is_open:
            if self._should_attempt_reset():
                logger.info(f"πŸ”„ Attempting to reset circuit: {self.name}")
                # NOTE(review): the breaker closes fully here, so a
                # still-failing service needs failure_threshold more failures
                # before it re-opens. Confirm this is intended.
                self.is_open = False
                self.failures = 0
            else:
                raise LojizError(
                    f"Circuit breaker open for {self.name}",
                    error_code="CIRCUIT_BREAKER_OPEN",
                    recoverable=True,
                )
        try:
            result = await asyncio.wait_for(
                func(*args, **kwargs),
                timeout=self.call_timeout,
            )
            # Reset on success
            self.failures = 0
            self.last_failure_time = None
            return result
        except Exception:
            self.failures += 1
            self.last_failure_time = time.time()
            if self.failures >= self.failure_threshold:
                logger.error(
                    f"πŸ”΄ Circuit breaker opened for {self.name} "
                    f"(failures: {self.failures})"
                )
                self.is_open = True
            raise

    def _should_attempt_reset(self) -> bool:
        """Return True when the recovery timeout has elapsed since the last failure."""
        # Compare against None explicitly rather than relying on truthiness
        # of a timestamp.
        if self.last_failure_time is None:
            return False
        return time.time() - self.last_failure_time >= self.recovery_timeout

    def get_status(self) -> dict:
        """Return a dict with the breaker's name, open flag, failure count and threshold."""
        return {
            "name": self.name,
            "is_open": self.is_open,
            "failures": self.failures,
            "failure_threshold": self.failure_threshold,
        }
# ============================================================
# Global Circuit Breakers
# ============================================================
# Module-level registry of CircuitBreaker instances, keyed by breaker name.
_circuit_breakers = {}
def get_circuit_breaker(name: str) -> CircuitBreaker:
    """Return the registered circuit breaker for ``name``, creating it on first use.

    A newly created breaker uses the ``CircuitBreaker`` constructor defaults
    and is stored in the module-level registry so later calls with the same
    name return the same instance.
    """
    try:
        return _circuit_breakers[name]
    except KeyError:
        breaker = CircuitBreaker(name)
        _circuit_breakers[name] = breaker
        return breaker
def get_all_circuit_breaker_status() -> dict:
    """Return a mapping of breaker name to its ``get_status()`` dict for every registered breaker."""
    snapshot = {}
    for breaker_name, breaker in _circuit_breakers.items():
        snapshot[breaker_name] = breaker.get_status()
    return snapshot