Spaces:
Sleeping
Sleeping
| """ | |
| Fallback Chains for API Failures | |
| This module implements intelligent fallback mechanisms that provide alternative | |
| data sources and strategies when primary APIs fail. | |
| """ | |
| import asyncio | |
| import time | |
| import math | |
| from datetime import datetime, timedelta | |
| from typing import Dict, List, Optional, Any, Callable, Union, Tuple | |
| from dataclasses import dataclass, field | |
| from enum import Enum | |
| import logging | |
| from functools import wraps | |
| from .error_categorization import CategorizedError, ErrorCategory, ErrorType | |
| from .retry_strategies import RetryConfig, RetryStrategy | |
class FallbackStrategy(str, Enum):
    """Different fallback strategies.

    Controls how FallbackChainManager walks the sources of a chain.
    """
    SEQUENTIAL = "sequential"  # Try fallbacks one at a time, in priority order
    PARALLEL = "parallel"  # Launch all fallbacks simultaneously, take the first success
    INTELLIGENT = "intelligent"  # Re-rank sources by reliability, then run sequentially
    CACHED_ONLY = "cached_only"  # Only use cached data; never call live sources
    PARTIAL_RESULTS = "partial_results"  # Return partial results if available
class DataSource(str, Enum):
    """Available data sources a fallback may draw from."""
    TAVILY = "tavily"  # Primary search API
    SERPAPI = "serpapi"  # Secondary search API
    CACHE = "cache"  # In-memory TTL cache kept by FallbackChainManager
    STATIC_DATA = "static_data"  # Bundled/static fallback data
    USER_PREFERENCES = "user_preferences"
    HISTORICAL_DATA = "historical_data"
@dataclass
class FallbackSource:
    """Definition of a fallback data source.

    Tracks per-source health (success/failure counts and timestamps) so the
    chain manager can rank sources by reliability.

    Note: the ``@dataclass`` decorator is required — without it the annotated
    fields generate no ``__init__`` and ``field(default_factory=dict)`` would
    leave a shared ``Field`` object as a class attribute, so the keyword
    construction used by the chain factories would raise TypeError.
    """
    name: str
    source_type: DataSource
    operation: Callable  # sync or async callable invoked with the chain context as kwargs
    priority: int = 1  # Lower number = higher priority
    timeout_seconds: float = 30.0
    success_rate: float = 1.0  # Tracked dynamically
    last_success: Optional[datetime] = None
    last_failure: Optional[datetime] = None
    failure_count: int = 0
    success_count: int = 0
    is_available: bool = True
    metadata: Dict[str, Any] = field(default_factory=dict)

    def get_reliability_score(self) -> float:
        """Calculate reliability score for this source in [0.0, 1.0].

        The score is the historical success rate, reduced by a penalty when
        the most recent attempt was a failure; the penalty decays
        exponentially (~1 hour time constant) so old failures fade.
        """
        total_attempts = self.success_count + self.failure_count
        if total_attempts == 0:
            # No history yet: assume the source is fully reliable.
            return 1.0
        base_success_rate = self.success_count / total_attempts
        # Penalize recent failures
        recency_penalty = 0.0
        if self.last_failure and self.last_success:
            if self.last_failure > self.last_success:
                time_since_failure = (datetime.now() - self.last_failure).total_seconds()
                # Reduce penalty over time (exponential decay)
                recency_penalty = 0.3 * math.exp(-time_since_failure / 3600)  # 1 hour half-life
        return max(0.0, base_success_rate - recency_penalty)
@dataclass
class FallbackResult:
    """Result of a fallback operation.

    ``@dataclass`` is required: the manager constructs these with keyword
    arguments (``FallbackResult(success=True, data=...)``), which only works
    with a generated ``__init__``; ``field(default_factory=...)`` also has no
    effect outside a dataclass.
    """
    success: bool
    data: Any = None
    source_name: str = ""  # Name of the source that produced the data
    source_type: DataSource = DataSource.CACHE
    execution_time_seconds: float = 0.0
    error: Optional[Exception] = None  # Set when success is False
    metadata: Dict[str, Any] = field(default_factory=dict)
    fallback_chain: List[str] = field(default_factory=list)  # Names of sources attempted, in order
@dataclass
class FallbackChain:
    """Definition of a fallback chain for a specific operation.

    ``@dataclass`` is required: this class is constructed with keyword
    arguments throughout the module, and ``field(default_factory=dict)``
    is only meaningful under the dataclass machinery.
    """
    operation_name: str  # Registry key used by FallbackChainManager
    strategy: FallbackStrategy
    sources: List[FallbackSource]
    max_execution_time: float = 60.0
    min_success_rate: float = 0.5
    cache_fallback_enabled: bool = True  # Check the TTL cache before any source runs
    partial_results_enabled: bool = True  # Allow returning partial results when all sources fail
    metadata: Dict[str, Any] = field(default_factory=dict)
class FallbackChainManager:
    """
    Manager for fallback chains that coordinates multiple data sources
    and implements intelligent fallback strategies.

    Holds three pieces of state:
      * a registry of FallbackChain definitions keyed by operation name,
      * aggregate per-source statistics (shared across chains),
      * an in-memory TTL cache used as a fast, last-known-good data source.
    """

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self._chains: Dict[str, FallbackChain] = {}
        # Aggregate statistics keyed by source name (shared across chains).
        self._source_statistics: Dict[str, Dict[str, Any]] = {}
        # TTL cache: values, insertion timestamps, and per-key TTLs.
        self._cache_store: Dict[str, Any] = {}
        self._cache_timestamps: Dict[str, datetime] = {}
        self._cache_ttl_seconds: Dict[str, float] = {}

    def register_chain(self, chain: FallbackChain):
        """Register a fallback chain under its operation name."""
        self._chains[chain.operation_name] = chain
        # Initialize source statistics
        for source in chain.sources:
            if source.name not in self._source_statistics:
                self._source_statistics[source.name] = {
                    "total_attempts": 0,
                    "successful_attempts": 0,
                    "failed_attempts": 0,
                    "avg_execution_time": 0.0,
                    "last_updated": datetime.now()
                }

    def get_or_create_chain(self, operation_name: str,
                            sources: List[FallbackSource],
                            strategy: FallbackStrategy = FallbackStrategy.SEQUENTIAL) -> FallbackChain:
        """Get existing chain or create (and register) a new one.

        If a chain already exists for ``operation_name``, the supplied
        ``sources``/``strategy`` are ignored and the existing chain returned.
        """
        if operation_name not in self._chains:
            chain = FallbackChain(
                operation_name=operation_name,
                strategy=strategy,
                sources=sources
            )
            self.register_chain(chain)
        return self._chains[operation_name]

    def cache_result(self, key: str, data: Any, ttl_seconds: float = 3600):
        """Cache a result for fallback use (default TTL: one hour)."""
        self._cache_store[key] = data
        self._cache_timestamps[key] = datetime.now()
        self._cache_ttl_seconds[key] = ttl_seconds

    def get_cached_result(self, key: str) -> Optional[Any]:
        """Get cached result if available and not expired.

        Returns None for both a missing key and an expired entry (expired
        entries are evicted as a side effect). Note a cached value of None
        is therefore indistinguishable from a miss.
        """
        if key not in self._cache_store:
            return None
        # Check if expired
        if key in self._cache_timestamps and key in self._cache_ttl_seconds:
            age_seconds = (datetime.now() - self._cache_timestamps[key]).total_seconds()
            if age_seconds > self._cache_ttl_seconds[key]:
                # Evict the expired entry from all three maps.
                del self._cache_store[key]
                del self._cache_timestamps[key]
                del self._cache_ttl_seconds[key]
                return None
        return self._cache_store[key]

    async def execute_fallback_chain(self, operation_name: str,
                                     context: Optional[Dict[str, Any]] = None,
                                     cache_key: Optional[str] = None) -> FallbackResult:
        """
        Execute a fallback chain for an operation.

        Args:
            operation_name: Name of a registered operation
            context: Keyword arguments passed through to each source operation
            cache_key: Key for reading/writing the TTL cache

        Returns:
            FallbackResult with data from the best available source
        """
        if operation_name not in self._chains:
            return FallbackResult(
                success=False,
                error=Exception(f"No fallback chain registered for {operation_name}")
            )
        chain = self._chains[operation_name]
        context = context or {}
        # Try cache first if enabled
        if chain.cache_fallback_enabled and cache_key:
            cached_data = self.get_cached_result(cache_key)
            # Compare against None so legitimately falsy cached values
            # (empty list, empty dict, 0, "") still count as cache hits.
            if cached_data is not None:
                self.logger.info(f"Using cached data for {operation_name}")
                return FallbackResult(
                    success=True,
                    data=cached_data,
                    source_name="cache",
                    source_type=DataSource.CACHE,
                    fallback_chain=["cache"]
                )
        # Dispatch on the configured strategy (default: sequential).
        if chain.strategy == FallbackStrategy.SEQUENTIAL:
            return await self._execute_sequential_fallback(chain, context, cache_key)
        elif chain.strategy == FallbackStrategy.PARALLEL:
            return await self._execute_parallel_fallback(chain, context, cache_key)
        elif chain.strategy == FallbackStrategy.INTELLIGENT:
            return await self._execute_intelligent_fallback(chain, context, cache_key)
        elif chain.strategy == FallbackStrategy.CACHED_ONLY:
            return await self._execute_cached_only_fallback(chain, context, cache_key)
        else:
            return await self._execute_sequential_fallback(chain, context, cache_key)

    async def _execute_sequential_fallback(self, chain: FallbackChain,
                                           context: Dict[str, Any],
                                           cache_key: Optional[str]) -> FallbackResult:
        """Execute fallback sources sequentially; return the first success."""
        fallback_chain = []
        partial_results = []
        # Sort sources by priority, breaking ties by reliability (highest first).
        sorted_sources = sorted(chain.sources, key=lambda s: (s.priority, -s.get_reliability_score()))
        for source in sorted_sources:
            if not source.is_available:
                continue
            fallback_chain.append(source.name)
            start_time = time.time()
            try:
                # Execute source with timeout
                if asyncio.iscoroutinefunction(source.operation):
                    data = await asyncio.wait_for(
                        source.operation(**context),
                        timeout=source.timeout_seconds
                    )
                else:
                    # Run sync operations in the default executor, bounded by
                    # the same per-source timeout as the async path. (The
                    # worker thread itself cannot be interrupted, only
                    # abandoned once the timeout fires.)
                    loop = asyncio.get_running_loop()
                    data = await asyncio.wait_for(
                        loop.run_in_executor(
                            None,
                            lambda: source.operation(**context)
                        ),
                        timeout=source.timeout_seconds
                    )
                execution_time = time.time() - start_time
                # Record success
                self._record_source_result(source, True, execution_time)
                # Cache result if key provided
                if cache_key:
                    self.cache_result(cache_key, data)
                return FallbackResult(
                    success=True,
                    data=data,
                    source_name=source.name,
                    source_type=source.source_type,
                    execution_time_seconds=execution_time,
                    fallback_chain=fallback_chain
                )
            except Exception as error:
                execution_time = time.time() - start_time
                self._record_source_result(source, False, execution_time)
                # Store partial results if the exception carries any.
                if hasattr(error, 'partial_results') and error.partial_results:
                    partial_results.extend(error.partial_results)
                self.logger.warning(f"Fallback source {source.name} failed: {error}")
                continue
        # All sources failed - return partial results if available
        if chain.partial_results_enabled and partial_results:
            return FallbackResult(
                success=True,
                data=partial_results,
                source_name="partial_results",
                source_type=DataSource.CACHE,
                fallback_chain=fallback_chain,
                metadata={"partial_results": True}
            )
        return FallbackResult(
            success=False,
            error=Exception(f"All fallback sources failed for {chain.operation_name}"),
            fallback_chain=fallback_chain
        )

    async def _execute_parallel_fallback(self, chain: FallbackChain,
                                         context: Dict[str, Any],
                                         cache_key: Optional[str]) -> FallbackResult:
        """Execute fallback sources in parallel; return the first success.

        All sources start concurrently; results are then inspected in list
        order, so an earlier source's result wins even if a later one
        finished first.
        """
        fallback_chain = []
        available_sources = [s for s in chain.sources if s.is_available]
        if not available_sources:
            return FallbackResult(
                success=False,
                error=Exception("No available fallback sources"),
                fallback_chain=[]
            )
        # Create tasks for all sources
        tasks = []
        for source in available_sources:
            fallback_chain.append(source.name)
            task = asyncio.create_task(self._execute_single_source(source, context))
            tasks.append((source, task))
        try:
            # Wait for first successful result
            for source, task in tasks:
                try:
                    result = await asyncio.wait_for(task, timeout=source.timeout_seconds)
                    if result.success:
                        # Cancel other tasks
                        for _, other_task in tasks:
                            if other_task != task and not other_task.done():
                                other_task.cancel()
                        # Cache result if key provided
                        if cache_key:
                            self.cache_result(cache_key, result.data)
                        return FallbackResult(
                            success=True,
                            data=result.data,
                            source_name=source.name,
                            source_type=source.source_type,
                            execution_time_seconds=result.execution_time_seconds,
                            fallback_chain=fallback_chain
                        )
                except asyncio.TimeoutError:
                    self._record_source_result(source, False, source.timeout_seconds)
                    continue
                except Exception:
                    # _execute_single_source normally swallows errors; this
                    # guards against cancellation or task-level failures.
                    self._record_source_result(source, False, 0.0)
                    continue
            # No sources succeeded
            return FallbackResult(
                success=False,
                error=Exception("All parallel fallback sources failed"),
                fallback_chain=fallback_chain
            )
        finally:
            # Ensure all tasks are cleaned up
            for _, task in tasks:
                if not task.done():
                    task.cancel()

    async def _execute_intelligent_fallback(self, chain: FallbackChain,
                                            context: Dict[str, Any],
                                            cache_key: Optional[str]) -> FallbackResult:
        """Execute intelligent fallback based on source health analysis.

        Currently this re-ranks the chain's sources by priority and observed
        reliability, then delegates to the sequential strategy. A more
        sophisticated implementation could analyze the specific error type
        and choose the most appropriate source.
        """
        sorted_sources = sorted(chain.sources, key=lambda s: (
            s.priority,
            -s.get_reliability_score(),
            -s.success_rate
        ))
        # Persist the new ordering on the chain itself.
        chain.sources = sorted_sources
        return await self._execute_sequential_fallback(chain, context, cache_key)

    async def _execute_cached_only_fallback(self, chain: FallbackChain,
                                            context: Dict[str, Any],
                                            cache_key: Optional[str]) -> FallbackResult:
        """Serve only from cache; fail if no (unexpired) entry exists."""
        if cache_key:
            cached_data = self.get_cached_result(cache_key)
            # None-check (not truthiness) so falsy cached values still hit.
            if cached_data is not None:
                return FallbackResult(
                    success=True,
                    data=cached_data,
                    source_name="cache",
                    source_type=DataSource.CACHE,
                    fallback_chain=["cache"]
                )
        return FallbackResult(
            success=False,
            error=Exception("No cached data available"),
            fallback_chain=[]
        )

    async def _execute_single_source(self, source: FallbackSource,
                                     context: Dict[str, Any]) -> FallbackResult:
        """Execute one source and wrap its outcome in a FallbackResult.

        Never raises: failures are recorded and returned as an unsuccessful
        result. Timeouts are enforced by the caller (parallel strategy).
        """
        start_time = time.time()
        try:
            if asyncio.iscoroutinefunction(source.operation):
                data = await source.operation(**context)
            else:
                loop = asyncio.get_running_loop()
                data = await loop.run_in_executor(None, lambda: source.operation(**context))
            execution_time = time.time() - start_time
            self._record_source_result(source, True, execution_time)
            return FallbackResult(
                success=True,
                data=data,
                source_name=source.name,
                source_type=source.source_type,
                execution_time_seconds=execution_time
            )
        except Exception as error:
            execution_time = time.time() - start_time
            self._record_source_result(source, False, execution_time)
            return FallbackResult(
                success=False,
                error=error,
                source_name=source.name,
                source_type=source.source_type,
                execution_time_seconds=execution_time
            )

    def _record_source_result(self, source: FallbackSource, success: bool, execution_time: float):
        """Record the result of a source execution.

        Updates the source's own counters/timestamps, the aggregate
        per-source statistics, and the source's availability flag.
        """
        if success:
            source.success_count += 1
            source.last_success = datetime.now()
        else:
            source.failure_count += 1
            source.last_failure = datetime.now()
        # Update global statistics
        if source.name not in self._source_statistics:
            self._source_statistics[source.name] = {
                "total_attempts": 0,
                "successful_attempts": 0,
                "failed_attempts": 0,
                "avg_execution_time": 0.0,
                "last_updated": datetime.now()
            }
        stats = self._source_statistics[source.name]
        stats["total_attempts"] += 1
        if success:
            stats["successful_attempts"] += 1
        else:
            stats["failed_attempts"] += 1
        # Update running average of execution time.
        total_time = stats["avg_execution_time"] * (stats["total_attempts"] - 1)
        stats["avg_execution_time"] = (total_time + execution_time) / stats["total_attempts"]
        stats["last_updated"] = datetime.now()
        # Update source availability based on failure rate
        total_attempts = source.success_count + source.failure_count
        if total_attempts >= 5:  # Only after sufficient data
            failure_rate = source.failure_count / total_attempts
            source.is_available = failure_rate < 0.8  # Mark unavailable if >80% failure rate

    def get_source_statistics(self, source_name: Optional[str] = None) -> Dict[str, Any]:
        """Get statistics for one source, or all sources when name is omitted."""
        if source_name:
            return self._source_statistics.get(source_name, {})
        return self._source_statistics

    def get_chain_statistics(self, operation_name: Optional[str] = None) -> Dict[str, Any]:
        """Get per-source statistics for one chain, or a summary of all chains."""
        if operation_name and operation_name in self._chains:
            chain = self._chains[operation_name]
            return {
                "operation_name": chain.operation_name,
                "strategy": chain.strategy.value,
                "sources": [
                    {
                        "name": s.name,
                        "source_type": s.source_type.value,
                        "priority": s.priority,
                        "success_rate": s.get_reliability_score(),
                        "is_available": s.is_available,
                        "success_count": s.success_count,
                        "failure_count": s.failure_count
                    }
                    for s in chain.sources
                ]
            }
        else:
            return {
                "chains": list(self._chains.keys()),
                "total_chains": len(self._chains)
            }
| # Convenience functions for creating common fallback chains | |
def create_api_fallback_chain(operation_name: str, primary_api: Callable,
                              fallback_apis: List[Callable],
                              cache_fallback: bool = True) -> FallbackChain:
    """Create a sequential fallback chain for API operations.

    Args:
        operation_name: Registry name for the chain.
        primary_api: Highest-priority callable (priority 1).
        fallback_apis: Lower-priority callables tried in order after it.
        cache_fallback: Enable the manager's cache lookup and append a
            cache sentinel source as the last resort.

    Returns:
        A FallbackChain using the SEQUENTIAL strategy.
    """
    sources = []
    # Primary API
    sources.append(FallbackSource(
        name="primary_api",
        source_type=DataSource.TAVILY,
        operation=primary_api,
        priority=1,
        timeout_seconds=30.0
    ))
    # Fallback APIs
    for i, api in enumerate(fallback_apis):
        source_type = DataSource.SERPAPI if i == 0 else DataSource.STATIC_DATA
        sources.append(FallbackSource(
            name=f"fallback_api_{i+1}",
            source_type=source_type,
            operation=api,
            priority=i+2,
            timeout_seconds=30.0
        ))
    # Cache fallback. Real cache reads happen in
    # FallbackChainManager.execute_fallback_chain *before* any source runs,
    # so this sentinel must raise to signal a miss. (Returning None here
    # would count as a success with None data, masking an all-sources-failed
    # condition and caching None under the key.)
    if cache_fallback:
        def _cache_sentinel(**kwargs):
            raise LookupError("cache sentinel source holds no data")
        sources.append(FallbackSource(
            name="cache",
            source_type=DataSource.CACHE,
            operation=_cache_sentinel,
            priority=999,
            timeout_seconds=1.0
        ))
    return FallbackChain(
        operation_name=operation_name,
        strategy=FallbackStrategy.SEQUENTIAL,
        sources=sources,
        cache_fallback_enabled=cache_fallback
    )
def create_search_fallback_chain(operation_name: str,
                                 tavily_search: Callable,
                                 serpapi_search: Callable,
                                 cached_search: Callable) -> FallbackChain:
    """Create a fallback chain specifically for search operations.

    Wires the three search backends (Tavily, SerpAPI, cache) into an
    INTELLIGENT-strategy chain with descending priority and per-source
    timeouts, with caching and partial results enabled.
    """
    # (name, source type, operation, priority, timeout) per backend.
    backend_specs = [
        ("tavily_search", DataSource.TAVILY, tavily_search, 1, 15.0),
        ("serpapi_search", DataSource.SERPAPI, serpapi_search, 2, 20.0),
        ("cached_search", DataSource.CACHE, cached_search, 3, 5.0),
    ]
    search_sources = [
        FallbackSource(
            name=spec_name,
            source_type=spec_type,
            operation=spec_op,
            priority=spec_priority,
            timeout_seconds=spec_timeout,
        )
        for spec_name, spec_type, spec_op, spec_priority, spec_timeout in backend_specs
    ]
    return FallbackChain(
        operation_name=operation_name,
        strategy=FallbackStrategy.INTELLIGENT,
        sources=search_sources,
        cache_fallback_enabled=True,
        partial_results_enabled=True,
    )
# Global fallback chain manager
_global_fallback_manager: Optional[FallbackChainManager] = None

def get_global_fallback_manager() -> FallbackChainManager:
    """Return the process-wide FallbackChainManager, creating it lazily."""
    global _global_fallback_manager
    if _global_fallback_manager is not None:
        return _global_fallback_manager
    _global_fallback_manager = FallbackChainManager()
    return _global_fallback_manager
# Decorator for fallback chains
def with_fallback_chain(operation_name: str, chain: Optional[FallbackChain] = None):
    """Decorator to add fallback chain execution to a function.

    Args:
        operation_name: Name the chain is (or will be) registered under.
        chain: Optional explicit chain; when omitted, the globally registered
            chain for ``operation_name`` is used if one exists, otherwise the
            wrapped function runs unprotected.
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)  # preserve the wrapped function's name/docstring
        async def async_wrapper(*args, **kwargs):
            manager = get_global_fallback_manager()
            if chain is not None:
                # Register the explicitly supplied chain: the manager
                # resolves chains by name only, so an unregistered chain
                # would otherwise be silently ignored.
                if operation_name not in manager._chains:
                    manager.register_chain(chain)
            elif operation_name not in manager._chains:
                # No chain available at all: run the function unprotected.
                return await func(*args, **kwargs)
            # Create cache key from function arguments
            cache_key = f"{operation_name}:{hash(str(args) + str(sorted(kwargs.items())))}"
            # Execute fallback chain
            context = {"args": args, "kwargs": kwargs, "function": func}
            result = await manager.execute_fallback_chain(
                operation_name, context, cache_key
            )
            if result.success:
                return result.data
            # Guard against result.error being None (raising None is a TypeError).
            raise result.error or Exception(f"Fallback chain failed for {operation_name}")

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            # For sync functions, run in event loop
            async def async_func():
                return await async_wrapper(*args, **kwargs)
            try:
                loop = asyncio.get_event_loop()
                return loop.run_until_complete(async_func())
            except RuntimeError:
                # No usable loop in this thread: create a private one.
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                try:
                    return loop.run_until_complete(async_func())
                finally:
                    loop.close()

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        else:
            return sync_wrapper
    return decorator