| """ |
| Reverse retry utilities. |
| """ |

import asyncio
import inspect
import random
from typing import Callable, Any, Optional

from app.core.logger import logger
from app.core.config import get_config
from app.core.exceptions import UpstreamException


class RetryContext:
    """Tracks retry attempts, accumulated delay, and backoff parameters."""

    def __init__(self):
        self.attempt = 0
        self.max_retry = int(get_config("retry.max_retry"))
        self.retry_codes = get_config("retry.retry_status_codes")
        self.last_error = None
        self.last_status = None
        self.total_delay = 0.0
        self.retry_budget = float(get_config("retry.retry_budget"))

        # Exponential backoff parameters.
        self.backoff_base = float(get_config("retry.retry_backoff_base"))
        self.backoff_factor = float(get_config("retry.retry_backoff_factor"))
        self.backoff_max = float(get_config("retry.retry_backoff_max"))

        # Seed for the decorrelated jitter used on 429 responses.
        self._last_delay = self.backoff_base
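
    # The keys above are read from the application config once per context.
    # A matching config fragment might look like this (illustrative values,
    # not the canonical schema):
    #
    #   retry:
    #     max_retry: 3
    #     retry_status_codes: [429, 500, 502, 503, 504]
    #     retry_budget: 60.0
    #     retry_backoff_base: 0.5
    #     retry_backoff_factor: 2.0
    #     retry_backoff_max: 30.0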

    def should_retry(self, status_code: int) -> bool:
        """Return True if another retry attempt is allowed for this status code."""
        if self.attempt >= self.max_retry:
            return False
        if status_code not in self.retry_codes:
            return False
        if self.total_delay >= self.retry_budget:
            return False
        return True

    def record_error(self, status_code: int, error: Exception):
        """Record the failed attempt's status code and exception."""
        self.last_status = status_code
        self.last_error = error
        self.attempt += 1

    def calculate_delay(self, status_code: int, retry_after: Optional[float] = None) -> float:
        """
        Calculate the backoff delay for the current attempt.

        Args:
            status_code: HTTP status code of the failed attempt
            retry_after: Retry-After header value (seconds), if any

        Returns:
            Delay time (seconds)
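
        Example (illustrative numbers, not the shipped defaults): with
        base=0.5, factor=2.0, and max=30.0, the generic path after the
        second failure draws uniformly from [0, min(0.5 * 2**2, 30.0)),
        i.e. [0, 2.0).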
| """ |
| |
| if retry_after is not None and retry_after > 0: |
| delay = min(retry_after, self.backoff_max) |
| self._last_delay = delay |
| return delay |
|
|
| |
| if status_code == 429: |
| |
| delay = random.uniform(self.backoff_base, self._last_delay * 3) |
| delay = min(delay, self.backoff_max) |
| self._last_delay = delay |
| return delay |
|
|
| |
| exp_delay = self.backoff_base * (self.backoff_factor**self.attempt) |
| delay = random.uniform(0, min(exp_delay, self.backoff_max)) |
| return delay |
|
|

    def record_delay(self, delay: float):
        """Add a delay to the running total counted against the retry budget."""
        self.total_delay += delay


def extract_retry_after(error: Exception) -> Optional[float]:
    """
    Extract a Retry-After value from an exception.

    Args:
        error: Exception object

    Returns:
        Retry-After value (seconds), or None if unavailable
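
    Example (illustrative; ``err`` is assumed to be an ``UpstreamException``
    whose ``details`` were populated by the upstream client)::

        err.details = {"headers": {"Retry-After": "7"}}
        extract_retry_after(err)  # -> 7.0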
| """ |
| if not isinstance(error, UpstreamException): |
| return None |
|
|
| details = error.details or {} |
|
|
| |
| retry_after = details.get("retry_after") |
| if retry_after is not None: |
| try: |
| return float(retry_after) |
| except (ValueError, TypeError): |
| pass |
|
|
| |
| headers = details.get("headers", {}) |
| if isinstance(headers, dict): |
| retry_after = headers.get("Retry-After") or headers.get("retry-after") |
| if retry_after is not None: |
| try: |
| return float(retry_after) |
| except (ValueError, TypeError): |
| pass |
|
|
| return None |


async def retry_on_status(
    func: Callable,
    *args,
    extract_status: Optional[Callable[[Exception], Optional[int]]] = None,
    on_retry: Optional[Callable[[int, int, Exception, float], Any]] = None,
    **kwargs,
) -> Any:
| """ |
| Generic retry function. |
| |
| Args: |
| func: Retry function |
| *args: Function arguments |
| extract_status: Function to extract status code from exception |
| on_retry: Callback function for retry (attempt, status_code, error, delay). |
| Can be sync or async. |
| **kwargs: Function keyword arguments |
| |
| Returns: |
| Function execution result |
| |
| Raises: |
| Last failed exception |
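
    Example (illustrative; ``fetch_upstream`` is a hypothetical coroutine,
    not part of this module)::

        result = await retry_on_status(
            fetch_upstream,
            "https://upstream.example/v1/data",
            on_retry=lambda attempt, status, err, delay: logger.debug(
                f"retry {attempt} for status {status}, sleeping {delay:.2f}s"
            ),
        )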
| """ |
| ctx = RetryContext() |
|
|
| |
| if extract_status is None: |
|
|
| def extract_status(e: Exception) -> Optional[int]: |
| if isinstance(e, UpstreamException): |
| |
| if e.details and "status" in e.details: |
| return e.details["status"] |
| return getattr(e, "status_code", None) |
| return None |
|

    while ctx.attempt <= ctx.max_retry:
        try:
            result = await func(*args, **kwargs)

            if ctx.attempt > 0:
                logger.info(
                    f"Retry succeeded after {ctx.attempt} attempts, "
                    f"total delay: {ctx.total_delay:.2f}s"
                )

            return result

        except Exception as e:
            status_code = extract_status(e)

            # Errors without a status code are never retried.
            if status_code is None:
                logger.error(f"Non-retryable error: {e}")
                raise

            ctx.record_error(status_code, e)

            if ctx.should_retry(status_code):
                retry_after = extract_retry_after(e)

                delay = ctx.calculate_delay(status_code, retry_after)

                # Never sleep past the total retry budget.
                if ctx.total_delay + delay > ctx.retry_budget:
                    logger.warning(
                        f"Retry budget exhausted: {ctx.total_delay:.2f}s "
                        f"+ {delay:.2f}s > {ctx.retry_budget}s"
                    )
                    raise

                ctx.record_delay(delay)

                logger.warning(
                    f"Retry {ctx.attempt}/{ctx.max_retry} for status {status_code}, "
                    f"waiting {delay:.2f}s (total: {ctx.total_delay:.2f}s)"
                    + (f", Retry-After: {retry_after}s" if retry_after else "")
                )

                # Notify the caller; the callback may be sync or async.
                if on_retry:
                    result = on_retry(ctx.attempt, status_code, e, delay)
                    if inspect.isawaitable(result):
                        await result

                await asyncio.sleep(delay)
                continue
            else:
                if status_code in ctx.retry_codes:
                    logger.error(
                        f"Retry exhausted after {ctx.attempt} attempts, "
                        f"last status: {status_code}, total delay: {ctx.total_delay:.2f}s"
                    )
                else:
                    logger.error(f"Non-retryable status code: {status_code}")

                raise


__all__ = [
    "RetryContext",
    "retry_on_status",
    "extract_retry_after",
]