# Taskflow-App: src/utils/circuit_breaker.py
# Commit 34e27fb ("code") by Tahasaif3
import asyncio
import time
from enum import Enum
from typing import Callable, Any, Awaitable
import logging
logger = logging.getLogger(__name__)


class CircuitState(Enum):
    """Lifecycle states of a :class:`CircuitBreaker`."""

    CLOSED = "closed"  # Normal operation: calls pass through, failures are counted
    OPEN = "open"  # Tripped: calls are rejected until recovery_timeout has elapsed
    HALF_OPEN = "half_open"  # Recovery test: one trial call decides CLOSED vs OPEN


class CircuitBreakerOpenError(Exception):
    """Raised by :meth:`CircuitBreaker.call` when the breaker rejects a call.

    Subclasses :class:`Exception`, so existing ``except Exception`` handlers
    around ``CircuitBreaker.call`` keep catching it.
    """


class CircuitBreaker:
    """Async circuit breaker that stops calling a dependency after repeated failures.

    State machine (every transition happens inside :meth:`call`):

    * CLOSED -> OPEN after ``failure_threshold`` consecutive failures whose
      exception matches ``expected_exception``; any success resets the count.
    * OPEN: calls are rejected with :class:`CircuitBreakerOpenError` until
      ``recovery_timeout`` seconds have passed since ``last_failure_time``.
    * OPEN -> HALF_OPEN on the first call after that timeout. Only one trial
      call is let through; concurrent callers are rejected while it runs.
      A success closes the circuit, a matching failure re-opens it.

    Exceptions that do not match ``expected_exception`` propagate unchanged
    and are not counted as failures.

    Args:
        failure_threshold: Consecutive matching failures needed to open.
        recovery_timeout: Seconds the circuit stays OPEN before a trial call
            is allowed. A trial call still running after this long is treated
            as stuck, and another trial is allowed.
        expected_exception: Exception class, or tuple of classes, that count
            as failures.
    """

    def __init__(self, failure_threshold: int = 5, recovery_timeout: float = 60.0, expected_exception: tuple = (Exception,)):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.expected_exception = expected_exception
        self.state = CircuitState.CLOSED
        # Consecutive matching failures; forced to failure_threshold when a
        # HALF_OPEN trial fails.
        self.failure_count = 0
        # Wall-clock time.time() of the failure that (re-)opened the circuit.
        self.last_failure_time = None
        self._lock = asyncio.Lock()
        # time.time() at which the current HALF_OPEN trial call started, or
        # None when no trial call is running.
        self._probe_started_at = None

    async def call(self, func: Callable[..., Awaitable[Any]], *args, **kwargs) -> Any:
        """Await ``func(*args, **kwargs)`` through the breaker and return its result.

        Raises:
            CircuitBreakerOpenError: The circuit is OPEN, or is HALF_OPEN with
                a trial call already in progress.
            Exception: Whatever ``func`` raises is re-raised after the
                failure is recorded.
        """
        async with self._lock:
            # NOTE(review): time.time() is wall-clock, so a clock step can
            # shorten or stretch the OPEN window. time.monotonic() would be
            # immune, but it changes the meaning of the public
            # last_failure_time. Confirm nothing reads it before switching.
            now = time.time()
            if self.state == CircuitState.OPEN:
                if now - self.last_failure_time < self.recovery_timeout:
                    raise CircuitBreakerOpenError("Circuit breaker is OPEN")
                self.state = CircuitState.HALF_OPEN
                logger.info("Circuit breaker transitioning to HALF_OPEN")

            if self.state == CircuitState.HALF_OPEN:
                # Only one caller may probe the recovering dependency. Without
                # this gate every concurrent caller would pass through here.
                probe_running = (
                    self._probe_started_at is not None
                    and now - self._probe_started_at < self.recovery_timeout
                )
                if probe_running:
                    raise CircuitBreakerOpenError(
                        "Circuit breaker is HALF_OPEN: trial call already in progress"
                    )
                self._probe_started_at = now
                is_probe = True
            elif self.state == CircuitState.CLOSED:
                is_probe = False
            else:
                raise Exception("Circuit breaker state unknown")

        # func runs outside the lock so slow calls do not serialize callers.
        if is_probe:
            return await self._call_half_open(func, args, kwargs)
        return await self._call_closed(func, args, kwargs)

    async def _call_half_open(self, func: Callable[..., Awaitable[Any]], args: tuple, kwargs: dict) -> Any:
        """Run the HALF_OPEN trial call, then close or re-open the circuit."""
        try:
            result = await func(*args, **kwargs)
        except self.expected_exception:
            async with self._lock:
                self.state = CircuitState.OPEN
                self.failure_count = self.failure_threshold  # Force open state
                self.last_failure_time = time.time()
                logger.warning("Circuit breaker opened after failed attempt in HALF_OPEN state")
            raise
        finally:
            # Always free the trial slot, including on unexpected errors and
            # cancellation, so the breaker cannot get stuck rejecting calls.
            self._probe_started_at = None
        async with self._lock:
            self.state = CircuitState.CLOSED
            self.failure_count = 0
            logger.info("Circuit breaker closed after successful call")
        return result

    async def _call_closed(self, func: Callable[..., Awaitable[Any]], args: tuple, kwargs: dict) -> Any:
        """Run a CLOSED-state call, opening the circuit after too many consecutive failures."""
        try:
            result = await func(*args, **kwargs)
        except self.expected_exception:
            async with self._lock:
                self.failure_count += 1
                if self.failure_count >= self.failure_threshold:
                    self.state = CircuitState.OPEN
                    self.last_failure_time = time.time()
                    logger.warning(
                        "Circuit breaker opened after %d consecutive failures",
                        self.failure_count,
                    )
                else:
                    logger.warning(
                        "Circuit breaker failure count: %d/%d",
                        self.failure_count,
                        self.failure_threshold,
                    )
            raise  # bare raise keeps the original traceback
        async with self._lock:
            # Reset failure count on success
            self.failure_count = 0
        return result
# Process-wide breaker shared by the Kafka connection code. It opens after
# three consecutive failures of any Exception type and waits 30 seconds
# before letting a recovery trial call through.
kafka_circuit_breaker = CircuitBreaker(
    expected_exception=(Exception,),
    recovery_timeout=30.0,
    failure_threshold=3,
)