"""Retry utilities for CVAT API requests.

This module provides retry logic with exponential backoff and special handling
for transient server errors (500/502/503/504, most notably 502 Bad Gateway),
which are common when the CVAT server is overloaded.
"""

import logging
import random
import time
from collections.abc import Callable
from functools import wraps
from typing import Any, TypeVar

import requests

T = TypeVar("T")


def calculate_retry_delay(
    attempt: int,
    base_delay: float,
    exponential_base: float,
    max_delay: float,
    jitter: bool = True,
) -> float:
    """Calculate the delay for the next retry attempt.

    The un-jittered delay is ``base_delay * exponential_base**attempt``,
    capped at ``max_delay``.

    Args:
        attempt: The current attempt number (0-based)
        base_delay: Initial delay in seconds
        exponential_base: Base for exponential backoff
        max_delay: Maximum allowed delay
        jitter: Whether to add random jitter

    Returns:
        The calculated delay in seconds
    """
    delay = min(base_delay * (exponential_base**attempt), max_delay)
    if jitter:
        # Scale to a random value between 50% and 100% of the capped delay so
        # that concurrent clients do not all retry at the same instant.
        delay = delay * (0.5 + random.random() * 0.5)
    return delay


def is_transient_error(exception: Exception) -> bool:
    """Check if the exception is a transient error (500, 502, 503, 504).

    Args:
        exception: The exception to check

    Returns:
        True if it's a transient server error, False otherwise
    """
    if not isinstance(exception, requests.HTTPError):
        return False
    if exception.response is None:
        return False
    return exception.response.status_code in (500, 502, 503, 504)


def should_retry_error(exception: Exception) -> bool:
    """Determine if an error should be retried.

    Args:
        exception: The exception to check

    Returns:
        True if the error should be retried, False otherwise
    """
    # Network errors - always retry
    if isinstance(exception, (requests.Timeout, requests.ConnectionError)):
        return True

    # HTTP errors - check status code
    if isinstance(exception, requests.HTTPError) and exception.response is not None:
        status_code = exception.response.status_code
        # Don't retry client errors except rate limiting (429)
        if 400 <= status_code < 500 and status_code != 429:
            return False
        # Retry server errors and rate limiting
        return True

    # Other request exceptions (including HTTPError without a response) - retry
    if isinstance(exception, requests.RequestException):
        return True

    # Anything that did not come from `requests` is not retried
    return False


def log_retry_attempt(
    logger: logging.Logger,
    func_name: str,
    exception: Exception,
    attempt: int,
    max_attempts: int,
    delay: float,
    is_transient: bool = False,
) -> None:
    """Log a retry attempt with appropriate context.

    Args:
        logger: Logger instance
        func_name: Name of the function being retried (accepted for interface
            symmetry with ``log_max_retries_reached``; not part of the message)
        exception: The exception that triggered the retry
        attempt: Current attempt number (1-based)
        max_attempts: Maximum number of attempts
        delay: Delay before next retry
        is_transient: Whether this is a transient server error (500/502/503/504)
    """
    if is_transient:
        # Compare against None explicitly: requests.Response defines its
        # truthiness as `.ok`, so every 4xx/5xx response is falsy and a plain
        # truthiness test would always report the status as "unknown".
        response = getattr(exception, "response", None)
        status_code = (
            getattr(response, "status_code", "unknown")
            if response is not None
            else "unknown"
        )
        logger.warning(
            "⚠️ %s error (attempt %d/%d), CVAT server may be overloaded. Waiting %.1fs before retry...",
            status_code,
            attempt,
            max_attempts,
            delay,
        )
    else:
        logger.warning(
            "⚠️ Request failed (attempt %d/%d), retrying in %.1fs... Error: %s",
            attempt,
            max_attempts,
            delay,
            exception,
        )


def log_max_retries_reached(
    logger: logging.Logger,
    func_name: str,
    max_retries: int,
    is_transient: bool = False,
) -> None:
    """Log when maximum retries have been reached.

    Args:
        logger: Logger instance
        func_name: Name of the function
        max_retries: Maximum retries that were attempted
        is_transient: Whether this was for transient errors specifically
    """
    if is_transient:
        logger.error(
            "❌ Max transient error retries (%d) reached for %s. "
            "CVAT server appears to be having persistent issues.",
            max_retries,
            func_name,
        )
    else:
        logger.error("❌ Max retries (%d) reached for %s", max_retries, func_name)


def retry_with_backoff(
    max_retries: int = 3,
    initial_delay: float = 1.0,
    max_delay: float = 60.0,
    exponential_base: float = 2.0,
    jitter: bool = True,
    # Transient error (500/502/503/504) specific parameters
    retry_transient: bool = True,
    max_transient_retries: int = 10,
    initial_transient_delay: float = 10.0,
    max_transient_delay: float = 300.0,
) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """Decorator that implements retry logic with exponential backoff.

    Special handling for transient server errors (500/502/503/504) with longer
    delays. These errors are common when CVAT server is overloaded or having
    temporary issues.

    The decorated callable must be a method: its first positional argument
    (``self``) must expose a logger either as ``self.client.logger`` or as
    ``self.logger``; it is looked up when a call fails.

    Args:
        max_retries: Maximum retry attempts for general errors
        initial_delay: Initial delay between retries in seconds
        max_delay: Maximum delay between retries in seconds
        exponential_base: Base for exponential backoff calculation
        jitter: Whether to add random jitter to prevent thundering herd
        retry_transient: Whether to apply special retry logic for transient errors
        max_transient_retries: Maximum retries specifically for transient errors
        initial_transient_delay: Initial delay for transient errors (longer than general)
        max_transient_delay: Maximum delay for transient errors

    Returns:
        A decorator that wraps a method with the configured retry behaviour.
    """

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @wraps(func)
        def wrapper(self, *args, **kwargs) -> T:
            return _execute_with_retry(
                func,
                self,
                args,
                kwargs,
                max_retries,
                initial_delay,
                max_delay,
                exponential_base,
                jitter,
                retry_transient,
                max_transient_retries,
                initial_transient_delay,
                max_transient_delay,
            )

        return wrapper

    return decorator


def _execute_with_retry(
    func: Callable[..., T],
    self,
    args: tuple,
    kwargs: dict,
    max_retries: int,
    initial_delay: float,
    max_delay: float,
    exponential_base: float,
    jitter: bool,
    retry_transient: bool,
    max_transient_retries: int,
    initial_transient_delay: float,
    max_transient_delay: float,
) -> T:
    """Execute a function with retry logic.

    This is separated from the decorator to keep functions small.

    Calls ``func(self, *args, **kwargs)`` up to ``effective_max_retries + 1``
    times, sleeping between attempts for the delay chosen by
    ``_handle_retry_exception``.

    Returns:
        Whatever ``func`` returns on its first successful call.

    Raises:
        Exception: The exception raised by ``func`` is re-raised unchanged when
            it is not retryable or when the retry budget is exhausted.
    """
    last_exception = None
    consecutive_transient_errors = 0

    # Determine max attempts based on transient error handling
    effective_max_retries = (
        max(max_retries, max_transient_retries) if retry_transient else max_retries
    )

    for attempt in range(effective_max_retries + 1):
        try:
            return func(self, *args, **kwargs)
        except Exception as e:
            last_exception = e

            # Handle the retry logic
            retry_result = _handle_retry_exception(
                e,
                self.client.logger if hasattr(self, "client") else self.logger,
                func.__name__,
                attempt,
                consecutive_transient_errors,
                max_retries,
                initial_delay,
                max_delay,
                exponential_base,
                jitter,
                retry_transient,
                max_transient_retries,
                initial_transient_delay,
                max_transient_delay,
                # Tell the handler when no further loop iteration exists, so it
                # does not schedule (and we do not sleep for) a retry that can
                # never happen.
                is_final_attempt=attempt >= effective_max_retries,
            )

            if retry_result is None:
                # Should not retry: propagate the original exception
                raise

            consecutive_transient_errors = retry_result["consecutive_transient_errors"]
            time.sleep(retry_result["delay"])

    # This should never be reached, but just in case
    raise last_exception if last_exception else RuntimeError("Unexpected retry error")


def _handle_retry_exception(
    exception: Exception,
    logger: logging.Logger,
    func_name: str,
    attempt: int,
    consecutive_transient_errors: int,
    max_retries: int,
    initial_delay: float,
    max_delay: float,
    exponential_base: float,
    jitter: bool,
    retry_transient: bool,
    max_transient_retries: int,
    initial_transient_delay: float,
    max_transient_delay: float,
    is_final_attempt: bool = False,
) -> dict[str, Any] | None:
    """Handle an exception during retry logic.

    Args:
        exception: The exception raised by the wrapped call
        logger: Logger used for all retry diagnostics
        func_name: Name of the wrapped function (for log messages)
        attempt: 0-based index of the attempt that just failed
        consecutive_transient_errors: Transient errors seen in a row before
            this failure
        max_retries: Maximum retry attempts for general errors
        initial_delay: Initial delay for general errors in seconds
        max_delay: Maximum delay for general errors in seconds
        exponential_base: Base for exponential backoff calculation
        jitter: Whether to add random jitter to the delay
        retry_transient: Whether transient errors get their own retry budget
        max_transient_retries: Maximum consecutive transient-error retries
        initial_transient_delay: Initial delay for transient errors in seconds
        max_transient_delay: Maximum delay for transient errors in seconds
        is_final_attempt: True when the caller has no attempts left; in that
            case no retry is scheduled regardless of the per-category budgets

    Returns:
        Dict with retry info (``"consecutive_transient_errors"`` and
        ``"delay"``) if should retry, None if should not retry
    """
    # Same total as the caller's loop bound, so "Attempt x/y" is accurate even
    # when transient-specific handling is disabled.
    effective_max_retries = (
        max(max_retries, max_transient_retries) if retry_transient else max_retries
    )
    logger.error(
        "❌ Request failed with error: %s. Attempt %d/%d.",
        exception,
        attempt + 1,
        effective_max_retries + 1,
    )

    # Check if it's a transient error (500/502/503/504)
    if is_transient_error(exception) and retry_transient:
        consecutive_transient_errors += 1

        if consecutive_transient_errors > max_transient_retries:
            log_max_retries_reached(
                logger, func_name, max_transient_retries, is_transient=True
            )
            return None

        if is_final_attempt:
            # The overall attempt budget ran out (earlier attempts were spent
            # on non-transient failures). `attempt` equals the number of
            # retries already performed.
            log_max_retries_reached(logger, func_name, attempt)
            return None

        # Calculate transient error-specific delay
        delay = calculate_retry_delay(
            consecutive_transient_errors - 1,
            initial_transient_delay,
            exponential_base,
            max_transient_delay,
            jitter,
        )

        log_retry_attempt(
            logger,
            func_name,
            exception,
            consecutive_transient_errors,
            max_transient_retries + 1,
            delay,
            is_transient=True,
        )

        return {
            "consecutive_transient_errors": consecutive_transient_errors,
            "delay": delay,
        }

    # Not a transient error - check if we should retry
    if not should_retry_error(exception):
        return None

    # Check if we've exceeded general retries. `attempt` counts every attempt
    # so far, including ones that failed with transient errors.
    if attempt >= max_retries or is_final_attempt:
        log_max_retries_reached(logger, func_name, max_retries)
        return None

    # Calculate regular delay
    delay = calculate_retry_delay(
        attempt, initial_delay, exponential_base, max_delay, jitter
    )

    log_retry_attempt(logger, func_name, exception, attempt + 1, max_retries + 1, delay)

    # A non-transient failure breaks the run of consecutive transient errors
    return {"consecutive_transient_errors": 0, "delay": delay}