import asyncio
import logging
import time
from enum import Enum
from typing import Any, Awaitable, Callable, Optional, Tuple, Type

logger = logging.getLogger(__name__)


class CircuitState(Enum):
    """States a :class:`CircuitBreaker` can be in."""

    CLOSED = "closed"  # Normal operation: calls pass through
    OPEN = "open"  # Tripped once failures reach the threshold: calls are rejected
    HALF_OPEN = "half_open"  # Recovery timeout elapsed: calls are let through as probes


class CircuitBreakerOpenError(Exception):
    """Raised by :meth:`CircuitBreaker.call` when a call is rejected because
    the circuit is OPEN.

    Subclasses ``Exception`` so existing ``except Exception`` handlers keep
    working, while new callers can tell a rejection apart from a failure of
    the wrapped call.
    """


class CircuitBreaker:
    """Async circuit breaker that stops calling a failing dependency.

    Consecutive failures are counted while CLOSED. Once ``failure_threshold``
    is reached the breaker goes OPEN and rejects calls until
    ``recovery_timeout`` seconds have passed since the last recorded failure.
    The next call then runs in HALF_OPEN as a probe: success closes the
    breaker, failure re-opens it.
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: float = 60.0,
        expected_exception: Tuple[Type[BaseException], ...] = (Exception,),
    ):
        """Create a breaker in the CLOSED state.

        Args:
            failure_threshold: Consecutive failures needed to open the circuit.
            recovery_timeout: Seconds to stay OPEN before allowing a probe call.
            expected_exception: Exception types counted as failures. Anything
                else raised by the wrapped call propagates without touching
                the breaker's state.
        """
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.expected_exception = expected_exception
        self.state = CircuitState.CLOSED
        self.failure_count = 0
        # Wall-clock timestamp (time.time()) of the failure that last opened
        # the circuit; None until the breaker has opened at least once.
        self.last_failure_time: Optional[float] = None
        self._lock = asyncio.Lock()

    async def call(self, func: Callable[..., Awaitable[Any]], *args, **kwargs) -> Any:
        """Await ``func(*args, **kwargs)`` under circuit-breaker protection.

        Args:
            func: Callable returning an awaitable.
            *args: Positional arguments forwarded to ``func``.
            **kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            Whatever the awaited ``func`` call returns.

        Raises:
            CircuitBreakerOpenError: The circuit is OPEN and the recovery
                timeout has not elapsed; ``func`` is not called.
            RuntimeError: ``self.state`` holds a value that is not handled.
            Exception: Anything raised by ``func`` is re-raised unchanged.
        """
        admitted_as = await self._admit()
        try:
            result = await func(*args, **kwargs)
        except self.expected_exception:
            await self._record_failure(admitted_as)
            # Bare raise keeps the original exception and traceback intact.
            raise
        await self._record_success(admitted_as)
        return result

    async def _admit(self) -> CircuitState:
        """Decide whether a call may run and return the state it runs under."""
        # asyncio.Lock is not reentrant, so each helper takes the lock on its
        # own and it is never held while the wrapped call is awaited.
        async with self._lock:
            if self.state == CircuitState.OPEN:
                opened_at = self.last_failure_time
                # A missing timestamp (state set to OPEN without a recorded
                # failure) is treated as eligible for a probe instead of
                # failing with a TypeError on the subtraction.
                if (
                    opened_at is not None
                    and time.time() - opened_at < self.recovery_timeout
                ):
                    raise CircuitBreakerOpenError("Circuit breaker is OPEN")
                self.state = CircuitState.HALF_OPEN
                logger.info("Circuit breaker transitioning to HALF_OPEN")

            if self.state not in (CircuitState.CLOSED, CircuitState.HALF_OPEN):
                raise RuntimeError("Circuit breaker state unknown")
            return self.state

    async def _record_success(self, admitted_as: CircuitState) -> None:
        """Reset the failure count; close the circuit after a successful probe."""
        async with self._lock:
            self.failure_count = 0
            if admitted_as == CircuitState.HALF_OPEN:
                self.state = CircuitState.CLOSED
                logger.info("Circuit breaker closed after successful call")

    async def _record_failure(self, admitted_as: CircuitState) -> None:
        """Count a failure and open the circuit when warranted."""
        async with self._lock:
            if admitted_as == CircuitState.HALF_OPEN:
                # A failed probe re-opens immediately; pin the counter at the
                # threshold so it stays consistent with the OPEN state.
                self.state = CircuitState.OPEN
                self.failure_count = self.failure_threshold
                self.last_failure_time = time.time()
                logger.warning(
                    "Circuit breaker opened after failed attempt in HALF_OPEN state"
                )
                return

            self.failure_count += 1
            if self.failure_count >= self.failure_threshold:
                self.state = CircuitState.OPEN
                self.last_failure_time = time.time()
                logger.warning(
                    "Circuit breaker opened after %s consecutive failures",
                    self.failure_count,
                )
            else:
                logger.warning(
                    "Circuit breaker failure count: %s/%s",
                    self.failure_count,
                    self.failure_threshold,
                )


# Global circuit breaker instance for Kafka connections
kafka_circuit_breaker = CircuitBreaker(
    failure_threshold=3,
    recovery_timeout=30.0,
    expected_exception=(Exception,),
)