Spaces:
Sleeping
Sleeping
File size: 3,942 Bytes
5df8a73 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 | """
Error Rate Tracker - Track error rates per provider with alerting.
"""
from collections import defaultdict, deque
import logging
import threading
import time
from typing import Callable, Dict, Optional
logger = logging.getLogger(__name__)
class ErrorRateTracker:
"""
Tracks error rates per provider with sliding window.
"""
def __init__(
self,
window_size: int = 60,
threshold: float = 0.5,
alert_callback: Optional[Callable[[str, float], None]] = None,
):
self.window_size = window_size # seconds
self.threshold = threshold # failure rate threshold
self.alert_callback = alert_callback
self._lock = threading.RLock() # Use RLock to allow reentrant locking
self._errors: Dict[str, deque[float]] = defaultdict(deque)
self._total_calls: Dict[str, deque[float]] = defaultdict(deque)
self._alerted: Dict[str, bool] = defaultdict(bool) # to avoid repeated alerts
def record_call(self, provider: str, success: bool):
"""Record a call for the provider."""
now = time.time()
with self._lock:
self._total_calls[provider].append(now)
if not success:
self._errors[provider].append(now)
self._cleanup_old_entries(provider, now)
self._check_alert(provider)
def get_error_rate(self, provider: str) -> float:
"""Get current error rate for provider."""
now = time.time()
with self._lock:
self._cleanup_old_entries(provider, now)
total = len(self._total_calls[provider])
errors = len(self._errors[provider])
return errors / total if total > 0 else 0.0
def check_threshold(self, provider: str) -> bool:
"""Check if error rate exceeds threshold."""
rate = self.get_error_rate(provider)
return rate > self.threshold
def _check_alert(self, provider: str):
"""Check and trigger alert if needed."""
rate = self.get_error_rate(provider)
exceeds_threshold = rate > self.threshold
if exceeds_threshold and not self._alerted[provider]:
logger.warning(
f"Provider {provider} error rate {rate:.2%} exceeds threshold {self.threshold:.2%}"
)
if self.alert_callback:
self.alert_callback(provider, rate)
self._alerted[provider] = True
elif not exceeds_threshold:
self._alerted[provider] = False # reset when below threshold
def _cleanup_old_entries(self, provider: str, now: float):
"""Remove entries older than window_size."""
cutoff = now - self.window_size
while self._total_calls[provider] and self._total_calls[provider][0] <= cutoff:
self._total_calls[provider].popleft()
while self._errors[provider] and self._errors[provider][0] <= cutoff:
self._errors[provider].popleft()
# Global instance
tracker = ErrorRateTracker()
# Set alert callback to circuit breaker
try:
from .network.circuit_breaker import alert_callback as cb
tracker.alert_callback = cb
except ImportError as e:
logging.getLogger(__name__).warning(
f"Circuit breaker module not available: {e}. Error rate tracking will work but circuit breaker integration is disabled."
)
def record_provider_call(provider: str, success: bool):
"""Global function to record a call."""
tracker.record_call(provider, success)
def get_provider_error_rate(provider: str) -> float:
"""Get error rate for provider."""
return tracker.get_error_rate(provider)
def check_provider_threshold(provider: str) -> bool:
"""Check if provider exceeds threshold."""
return tracker.check_threshold(provider)
def set_alert_callback(callback: Callable[[str, float], None]):
"""Set the alert callback for the tracker."""
tracker.alert_callback = callback
|