File size: 3,415 Bytes
ae5413a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""Retry middleware for chat clients with exponential backoff."""

import asyncio
from collections.abc import Awaitable, Callable

import structlog
from agent_framework._middleware import ChatContext, ChatMiddleware

logger = structlog.get_logger()


class RetryMiddleware(ChatMiddleware):
    """Retries failed chat requests with exponential backoff.

    This middleware intercepts chat client calls and retries on transient
    errors (rate limits, timeouts, server errors).

    Attributes:
        max_attempts: Maximum number of attempts, counting the first call (default: 3)
        min_wait: Minimum wait between retries in seconds (default: 1.0)
        max_wait: Maximum wait between retries in seconds (default: 10.0)
        retryable_status_codes: HTTP status codes to retry (default: 429, 500, 502, 503, 504)
    """

    def __init__(
        self,
        max_attempts: int = 3,
        min_wait: float = 1.0,
        max_wait: float = 10.0,
        retryable_status_codes: tuple[int, ...] = (429, 500, 502, 503, 504),
    ) -> None:
        """Store the retry configuration.

        Args:
            max_attempts: Total number of attempts, including the first one.
            min_wait: Base delay in seconds; doubled after each failed attempt.
            max_wait: Upper bound in seconds for any single delay.
            retryable_status_codes: Status codes that qualify an error for retry.

        Raises:
            ValueError: If ``max_attempts`` is less than 1. With zero attempts
                the wrapped call would never run and ``process`` would return
                as if the request had succeeded.
        """
        if max_attempts < 1:
            raise ValueError(f"max_attempts must be at least 1, got {max_attempts}")

        self.max_attempts = max_attempts
        self.min_wait = min_wait
        self.max_wait = max_wait
        self.retryable_status_codes = retryable_status_codes

    def _is_retryable(self, error: Exception) -> bool:
        """Check if error is retryable.

        An error that carries ``response.status_code`` is judged by that status
        code alone. Any other error is retryable when its class name contains
        "timeout" or "connection" (case-insensitive).

        Args:
            error: The exception raised by the wrapped call.

        Returns:
            True if the request should be attempted again, False otherwise.
        """
        # Check for httpx-style status errors; the status code is decisive
        # here, so the class-name checks below are not consulted.
        if hasattr(error, "response") and hasattr(error.response, "status_code"):
            return error.response.status_code in self.retryable_status_codes

        # Check for timeout and connection errors by class name so that no
        # specific HTTP client library has to be imported.
        error_name = type(error).__name__.lower()
        return "timeout" in error_name or "connection" in error_name

    def _calculate_wait(self, attempt: int) -> float:
        """Calculate wait time with exponential backoff.

        Args:
            attempt: Zero-based index of the attempt that just failed.

        Returns:
            ``min_wait * 2**attempt`` capped at ``max_wait``, as a float.
        """
        try:
            wait = self.min_wait * (2**attempt)
        except OverflowError:
            # 2**attempt no longer fits in a float (attempt >= 1024 with a
            # float min_wait). Saturate instead of letting OverflowError
            # replace the request error being handled by the caller.
            wait = self.max_wait if self.min_wait > 0 else self.min_wait
        return float(min(wait, self.max_wait))

    async def process(
        self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
    ) -> None:
        """Process the chat request with retry logic.

        Awaits ``next(context)`` up to ``max_attempts`` times, sleeping with
        exponential backoff between attempts.

        Args:
            context: The chat context passed through to ``next``.
            next: The next handler in the middleware chain.

        Raises:
            Exception: The error from ``next`` is re-raised unchanged, either
                immediately when it is not retryable or after the final
                attempt has failed.
        """
        for attempt in range(self.max_attempts):
            try:
                await next(context)
            except Exception as e:
                if not self._is_retryable(e):
                    logger.warning(
                        "Non-retryable error",
                        error=str(e),
                        error_type=type(e).__name__,
                    )
                    raise  # Don't retry non-retryable errors

                if attempt >= self.max_attempts - 1:
                    # All retries exhausted; a bare raise keeps the original
                    # traceback intact.
                    logger.error(
                        "All retry attempts failed",
                        max_attempts=self.max_attempts,
                        last_error=str(e),
                    )
                    raise

                wait_time = self._calculate_wait(attempt)
                logger.info(
                    "Retrying after error",
                    attempt=attempt + 1,
                    max_attempts=self.max_attempts,
                    wait_seconds=wait_time,
                    error=str(e),
                )
                await asyncio.sleep(wait_time)
            else:
                return  # Success - exit retry loop