File size: 3,169 Bytes
34e27fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import asyncio
import time
from enum import Enum
from typing import Callable, Any, Awaitable
import logging

# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)


class CircuitState(Enum):
    """Lifecycle states of a circuit breaker.

    CLOSED    -- normal operation; calls flow through.
    OPEN      -- tripped after too many failures; calls are rejected.
    HALF_OPEN -- probing whether the protected service has recovered.
    """

    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half_open"


class CircuitBreakerOpenError(Exception):
    """Raised when a CircuitBreaker rejects a call because its circuit is OPEN.

    Subclasses ``Exception`` (the type the breaker raised before this class
    existed), so existing ``except Exception`` handlers still catch it.
    """


class CircuitBreaker:
    """Async circuit breaker guarding calls to a coroutine function.

    States:
      * CLOSED: calls pass through. Each failure matching
        ``expected_exception`` increments ``failure_count``; any success
        resets it to 0. Reaching ``failure_threshold`` opens the circuit.
      * OPEN: calls are rejected with ``CircuitBreakerOpenError`` until
        ``recovery_timeout`` seconds have passed since the circuit opened.
      * HALF_OPEN: calls are let through as trials. A success closes the
        circuit; a matching failure re-opens it.

    Exceptions that do not match ``expected_exception`` propagate without
    changing the breaker's state.

    Args:
        failure_threshold: Consecutive matching failures that open the circuit.
        recovery_timeout: Seconds to stay OPEN before allowing a trial call.
        expected_exception: Exception class(es) counted as failures.
    """

    def __init__(self, failure_threshold: int = 5, recovery_timeout: float = 60.0, expected_exception: tuple = (Exception,)):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.expected_exception = expected_exception

        self.state = CircuitState.CLOSED
        self.failure_count = 0
        # Wall-clock (time.time()) timestamp of the failure that last opened the circuit.
        self.last_failure_time = None
        # Monotonic timestamp of the same event. The recovery timeout is measured
        # against this so wall-clock adjustments cannot stretch or shorten the
        # OPEN period.
        self._opened_at = None
        self._lock = asyncio.Lock()

    def _trip(self) -> None:
        """Move to OPEN and record when it happened. Caller must hold ``_lock``."""
        self.state = CircuitState.OPEN
        self.last_failure_time = time.time()
        self._opened_at = time.monotonic()

    def _seconds_since_open(self) -> float:
        """Return how many seconds the circuit has been OPEN."""
        if self._opened_at is None:
            # OPEN was set without going through _trip() (e.g. assigned
            # directly); fall back to the wall-clock timestamp.
            return time.time() - self.last_failure_time
        return time.monotonic() - self._opened_at

    async def call(self, func: Callable[..., Awaitable[Any]], *args, **kwargs) -> Any:
        """Await ``func(*args, **kwargs)`` through the breaker.

        Returns:
            Whatever ``func`` returns.

        Raises:
            CircuitBreakerOpenError: The circuit is OPEN and the recovery
                timeout has not elapsed; ``func`` is not called.
            Exception: Any exception raised by ``func`` is re-raised unchanged
                after the breaker's bookkeeping has been updated.
        """
        async with self._lock:
            if self.state == CircuitState.OPEN:
                if self._seconds_since_open() < self.recovery_timeout:
                    raise CircuitBreakerOpenError("Circuit breaker is OPEN")
                self.state = CircuitState.HALF_OPEN
                logger.info("Circuit breaker transitioning to HALF_OPEN")
            # Snapshot the state while still holding the lock. Re-reading
            # self.state after release could observe OPEN (set by a concurrent
            # failing call) and fall through to the "state unknown" branch.
            state = self.state

        # NOTE: HALF_OPEN does not limit concurrency; every call that observes
        # HALF_OPEN here is let through as a trial.
        if state == CircuitState.HALF_OPEN:
            return await self._call_half_open(func, *args, **kwargs)
        if state == CircuitState.CLOSED:
            return await self._call_closed(func, *args, **kwargs)
        raise Exception("Circuit breaker state unknown")

    async def _call_half_open(self, func: Callable[..., Awaitable[Any]], *args, **kwargs) -> Any:
        """Run a trial call: success closes the circuit, a matching failure re-opens it."""
        try:
            result = await func(*args, **kwargs)
        except self.expected_exception:
            async with self._lock:
                self.failure_count = self.failure_threshold  # keep count consistent with OPEN
                self._trip()
                logger.warning("Circuit breaker opened after failed attempt in HALF_OPEN state")
            raise
        async with self._lock:
            self.state = CircuitState.CLOSED
            self.failure_count = 0
            self._opened_at = None
            logger.info("Circuit breaker closed after successful call")
        return result

    async def _call_closed(self, func: Callable[..., Awaitable[Any]], *args, **kwargs) -> Any:
        """Run a normal call, counting consecutive matching failures."""
        try:
            result = await func(*args, **kwargs)
        except self.expected_exception:
            async with self._lock:
                self.failure_count += 1
                if self.failure_count >= self.failure_threshold:
                    self._trip()
                    logger.warning("Circuit breaker opened after %s consecutive failures", self.failure_count)
                else:
                    logger.warning("Circuit breaker failure count: %s/%s", self.failure_count, self.failure_threshold)
            raise
        async with self._lock:
            # Any success breaks the run of consecutive failures.
            self.failure_count = 0
        return result


# Shared, module-wide breaker for Kafka connections: opens after 3 consecutive
# failures of any Exception subclass and allows a trial call after 30 seconds.
kafka_circuit_breaker = CircuitBreaker(
    expected_exception=(Exception,),
    recovery_timeout=30.0,
    failure_threshold=3,
)