Spaces:
Sleeping
Sleeping
| import asyncio | |
| import time | |
| from enum import Enum | |
| from typing import Callable, Any, Awaitable | |
| import logging | |
# Logger named after this module; all breaker state transitions are reported here.
logger = logging.getLogger(name=__name__)
class CircuitState(Enum):
    """Lifecycle states of a circuit breaker.

    Each member's value is its lowercase name.
    """

    CLOSED = "closed"  # normal operation: calls pass through
    OPEN = "open"  # tripped after failures reached the threshold
    HALF_OPEN = "half_open"  # probing whether the service has recovered
class CircuitBreakerOpenError(Exception):
    """Raised when ``CircuitBreaker.call`` rejects a call without attempting it.

    Subclasses ``Exception``, so handlers that caught the generic
    ``Exception`` previously raised for an open circuit keep working.
    """


class CircuitBreaker:
    """Async circuit breaker guarding calls to an unreliable dependency.

    States (``CircuitState``):

    * CLOSED: calls pass through. Each failure matching
      ``expected_exception`` increments ``failure_count``; a success resets
      it. Reaching ``failure_threshold`` consecutive failures opens the
      circuit.
    * OPEN: calls are rejected with ``CircuitBreakerOpenError`` until
      ``recovery_timeout`` seconds have elapsed since the circuit opened;
      the next call then moves the breaker to HALF_OPEN.
    * HALF_OPEN: a single trial call is let through while concurrent
      callers are rejected. Its success closes the circuit; a matching
      failure re-opens it.

    Exceptions that do not match ``expected_exception`` are re-raised and
    leave the state and failure count untouched. Outcomes of calls admitted
    while CLOSED that finish after the circuit has opened are ignored, so
    late results cannot re-trip the breaker or extend the OPEN period.
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: float = 60.0,
        expected_exception: tuple = (Exception,),
    ):
        """Create a breaker in the CLOSED state.

        Args:
            failure_threshold: Consecutive counted failures that open the
                circuit.
            recovery_timeout: Seconds the circuit stays OPEN before a trial
                call is allowed.
            expected_exception: Exception class or tuple of classes (anything
                valid in an ``except`` clause) that count as failures.
        """
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.expected_exception = expected_exception
        self.state = CircuitState.CLOSED
        self.failure_count = 0
        # Wall-clock (time.time()) moment the circuit last opened; kept as a
        # public, human-readable timestamp.
        self.last_failure_time = None
        # Monotonic reading taken at the same moment. The recovery timeout is
        # measured against this so system clock adjustments cannot shorten
        # or prolong the OPEN period.
        self._opened_at = None
        # True while the single HALF_OPEN trial call is running.
        self._probe_in_flight = False
        self._lock = asyncio.Lock()

    async def call(self, func: Callable[..., Awaitable[Any]], *args, **kwargs) -> Any:
        """Await ``func(*args, **kwargs)`` through the breaker.

        Returns:
            Whatever ``func`` returns.

        Raises:
            CircuitBreakerOpenError: The circuit is OPEN and the recovery
                timeout has not elapsed, or it is HALF_OPEN and another trial
                call is already running.
            Exception: Anything raised by ``func`` is re-raised unchanged.
        """
        async with self._lock:
            is_probe = self._admit()
        try:
            result = await func(*args, **kwargs)
        except self.expected_exception:
            async with self._lock:
                self._record_failure(is_probe)
            raise
        else:
            async with self._lock:
                self._record_success(is_probe)
            return result
        finally:
            if is_probe:
                # Free the trial slot however func finished (including
                # cancellation or an exception type that is not counted) so
                # the breaker cannot get stuck in HALF_OPEN.
                self._probe_in_flight = False

    def _admit(self) -> bool:
        """Admit or reject a call; caller holds ``_lock``.

        Returns True if the admitted call is the HALF_OPEN trial call and
        False for a normal CLOSED call. Raises CircuitBreakerOpenError when
        the call is rejected.
        """
        if self.state == CircuitState.OPEN:
            if not self._recovery_timeout_elapsed():
                raise CircuitBreakerOpenError("Circuit breaker is OPEN")
            self.state = CircuitState.HALF_OPEN
            logger.info("Circuit breaker transitioning to HALF_OPEN")
        if self.state == CircuitState.HALF_OPEN:
            if self._probe_in_flight:
                raise CircuitBreakerOpenError(
                    "Circuit breaker is OPEN (HALF_OPEN trial call in progress)"
                )
            self._probe_in_flight = True
            return True
        if self.state == CircuitState.CLOSED:
            return False
        raise Exception("Circuit breaker state unknown")

    def _recovery_timeout_elapsed(self) -> bool:
        """Return True once the circuit has been OPEN for ``recovery_timeout`` seconds."""
        if self._opened_at is not None:
            return time.monotonic() - self._opened_at >= self.recovery_timeout
        if self.last_failure_time is not None:
            # OPEN was set from outside without a monotonic timestamp; fall
            # back to the wall-clock reading the public attribute provides.
            return time.time() - self.last_failure_time >= self.recovery_timeout
        return True

    def _trip(self) -> None:
        """Move to OPEN and timestamp the transition; caller holds ``_lock``."""
        self.state = CircuitState.OPEN
        self.last_failure_time = time.time()
        self._opened_at = time.monotonic()

    def _record_success(self, is_probe: bool) -> None:
        """Apply a successful call's outcome; caller holds ``_lock``."""
        if is_probe:
            self.state = CircuitState.CLOSED
            self.failure_count = 0
            self._opened_at = None
            logger.info("Circuit breaker closed after successful call")
        elif self.state == CircuitState.CLOSED:
            # Reset failure count on success.
            self.failure_count = 0
        # Otherwise the call was admitted while CLOSED but finished after the
        # circuit opened: its result is stale and must not change the state.

    def _record_failure(self, is_probe: bool) -> None:
        """Apply a counted failure's outcome; caller holds ``_lock``."""
        if is_probe:
            self._trip()
            self.failure_count = self.failure_threshold  # Force open state
            logger.warning("Circuit breaker opened after failed attempt in HALF_OPEN state")
        elif self.state == CircuitState.CLOSED:
            self.failure_count += 1
            if self.failure_count >= self.failure_threshold:
                self._trip()
                logger.warning(
                    "Circuit breaker opened after %s consecutive failures",
                    self.failure_count,
                )
            else:
                logger.warning(
                    "Circuit breaker failure count: %s/%s",
                    self.failure_count,
                    self.failure_threshold,
                )
        # Otherwise the failure is stale (the call started before the circuit
        # opened); ignoring it keeps it from extending the OPEN period or
        # aborting an ongoing HALF_OPEN trial.
# Shared, module-level breaker for Kafka connections: it opens after 3
# consecutive failures of any Exception type and allows a trial call again
# once 30 seconds have passed.
kafka_circuit_breaker = CircuitBreaker(
    expected_exception=(Exception,),
    recovery_timeout=30.0,
    failure_threshold=3,
)