VibecoderMcSwaggins's picture
refactor: implement proper middleware architecture (SPEC-21)
ae5413a
raw
history blame
3.42 kB
"""Retry middleware for chat clients with exponential backoff."""
import asyncio
from collections.abc import Awaitable, Callable
import structlog
from agent_framework._middleware import ChatContext, ChatMiddleware
# Module-level structlog logger; RetryMiddleware uses it to report
# non-retryable errors, scheduled retries, and exhausted retry budgets.
logger = structlog.get_logger()
class RetryMiddleware(ChatMiddleware):
    """Retries failed chat requests with exponential backoff.

    This middleware intercepts chat client calls and retries on transient
    errors (rate limits, timeouts, server errors).

    Attributes:
        max_attempts: Maximum number of attempts (default: 3). Values below 1
            are treated as 1 by ``process`` so the request is always sent.
        min_wait: Minimum wait between retries in seconds (default: 1.0)
        max_wait: Maximum wait between retries in seconds (default: 10.0)
        retryable_status_codes: HTTP status codes to retry
            (default: 429, 500, 502, 503, 504)
    """

    def __init__(
        self,
        max_attempts: int = 3,
        min_wait: float = 1.0,
        max_wait: float = 10.0,
        retryable_status_codes: tuple[int, ...] = (429, 500, 502, 503, 504),
    ) -> None:
        """Store the retry policy.

        Args:
            max_attempts: Total number of attempts, including the first call.
            min_wait: Base delay in seconds; doubled after each failed attempt.
            max_wait: Upper bound in seconds for any single delay.
            retryable_status_codes: Response status codes treated as transient.
        """
        self.max_attempts = max_attempts
        self.min_wait = min_wait
        self.max_wait = max_wait
        self.retryable_status_codes = retryable_status_codes

    def _is_retryable(self, error: Exception) -> bool:
        """Return True if ``error`` looks transient and worth retrying.

        An error carrying ``response.status_code`` is judged solely by that
        status code. Otherwise the exception class name is inspected for
        "timeout" or "connection" (case-insensitive).
        """
        # Errors exposing an HTTP-style response (e.g. httpx status errors)
        # are decided by status code alone; the name heuristics are skipped.
        if hasattr(error, "response") and hasattr(error.response, "status_code"):
            return error.response.status_code in self.retryable_status_codes

        # Name-based heuristic avoids importing any specific HTTP client's
        # exception hierarchy.
        error_name = type(error).__name__.lower()
        return "timeout" in error_name or "connection" in error_name

    def _calculate_wait(self, attempt: int) -> float:
        """Return the backoff delay in seconds for a zero-based ``attempt``.

        The delay is ``min_wait * 2**attempt``, capped at ``max_wait``.
        """
        wait = self.min_wait * (2**attempt)
        return float(min(wait, self.max_wait))

    async def process(
        self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
    ) -> None:
        """Run the downstream handler, retrying transient failures.

        Args:
            context: Chat context passed unchanged to ``next`` on every attempt.
            next: Downstream handler to await. The parameter name shadows the
                builtin but is kept so keyword callers keep working.

        Raises:
            Exception: The downstream error, re-raised unchanged, when it is
                not retryable or when the final attempt also fails.
        """
        # Always make at least one attempt: with max_attempts <= 0 the loop
        # would otherwise never call ``next`` and return as if successful.
        total_attempts = max(1, self.max_attempts)

        for attempt in range(total_attempts):
            try:
                await next(context)
            except Exception as e:
                if not self._is_retryable(e):
                    logger.warning(
                        "Non-retryable error",
                        error=str(e),
                        error_type=type(e).__name__,
                    )
                    raise  # Don't retry non-retryable errors

                if attempt >= total_attempts - 1:
                    # Retry budget exhausted. Re-raise from inside the handler
                    # so the original traceback is preserved and the error can
                    # never be dropped by a truthiness check.
                    logger.error(
                        "All retry attempts failed",
                        max_attempts=total_attempts,
                        last_error=str(e),
                    )
                    raise

                wait_time = self._calculate_wait(attempt)
                logger.info(
                    "Retrying after error",
                    attempt=attempt + 1,
                    max_attempts=total_attempts,
                    wait_seconds=wait_time,
                    error=str(e),
                )
                await asyncio.sleep(wait_time)
            else:
                return  # Success - exit retry loop