Thibaut's picture
Add complete metrics evaluation subproject structure
b7d2408
"""Retry utilities for CVAT API requests.
This module provides retry logic with exponential backoff and special handling
for transient server errors (500/502/503/504), which are common when the CVAT
server is overloaded.
"""
import logging
import random
import time
from collections.abc import Callable
from functools import wraps
from typing import TypeVar
import requests
T = TypeVar("T")
def calculate_retry_delay(
    attempt: int,
    base_delay: float,
    exponential_base: float,
    max_delay: float,
    jitter: bool = True,
) -> float:
    """Calculate the delay for the next retry attempt.

    The delay grows as ``base_delay * exponential_base ** attempt`` and is
    capped at ``max_delay``. With jitter enabled the capped value is scaled
    by a random factor in ``[0.5, 1.0)``.

    Args:
        attempt: The current attempt number (0-based)
        base_delay: Initial delay in seconds
        exponential_base: Base for exponential backoff
        max_delay: Maximum allowed delay
        jitter: Whether to add random jitter

    Returns:
        The calculated delay in seconds
    """
    try:
        uncapped = base_delay * (exponential_base**attempt)
    except OverflowError:
        # The exponential term left the float range (very large attempt
        # numbers). For a positive base delay the true value is far above
        # the cap, so fall back to the cap instead of crashing the retry loop.
        uncapped = max_delay if base_delay > 0 else 0.0

    delay = min(uncapped, max_delay)
    if jitter:
        # Scale to between 50% and 100% of the delay so that concurrent
        # clients do not all retry at the same instant.
        delay = delay * (0.5 + random.random() * 0.5)
    return delay
def is_transient_error(exception: Exception) -> bool:
    """Tell whether an exception represents a transient server error.

    Only ``requests.HTTPError`` instances that carry a response are
    considered; the response status must be one of 500, 502, 503 or 504.

    Args:
        exception: The exception to check

    Returns:
        True if it's a transient server error, False otherwise
    """
    response = (
        exception.response if isinstance(exception, requests.HTTPError) else None
    )
    # Identity comparison on purpose: only a missing response disqualifies.
    return response is not None and response.status_code in (500, 502, 503, 504)
def should_retry_error(exception: Exception) -> bool:
    """Decide whether a failed request is worth retrying.

    Timeouts and connection errors are always retried. An HTTP error with a
    response is retried unless its status is a 4xx other than 429. Any other
    ``requests.RequestException`` is retried; everything else is not.

    Args:
        exception: The exception to check

    Returns:
        True if the error should be retried, False otherwise
    """
    # Network-level failures: always worth another try.
    if isinstance(exception, (requests.Timeout, requests.ConnectionError)):
        return True

    # HTTP failures that carry a response: decide on the status code.
    if isinstance(exception, requests.HTTPError):
        response = exception.response
        if response is not None:
            code = response.status_code
            # Rate limiting (429) and anything outside 4xx is retried;
            # the remaining client errors are not.
            return code == 429 or not 400 <= code < 500

    # Remaining request-level exceptions are retried, all others are not.
    return isinstance(exception, requests.RequestException)
def log_retry_attempt(
    logger: logging.Logger,
    func_name: str,
    exception: Exception,
    attempt: int,
    max_attempts: int,
    delay: float,
    is_transient: bool = False,
) -> None:
    """Log a retry attempt with appropriate context.

    Emits a single warning. For transient errors the message carries the
    HTTP status code of the response attached to the exception (or
    ``"unknown"`` when there is none); otherwise it carries the exception
    itself.

    Args:
        logger: Logger instance
        func_name: Name of the function being retried (accepted for symmetry
            with the other logging helpers; not part of the messages)
        exception: The exception that triggered the retry
        attempt: Current attempt number (1-based)
        max_attempts: Maximum number of attempts
        delay: Delay before next retry
        is_transient: Whether this is a transient server error (500/502/503/504)
    """
    if is_transient:
        # Compare against None explicitly: a requests.Response is falsy for
        # every status >= 400, so a truthiness test would always report
        # "unknown" for exactly the errors this branch is meant to describe.
        response = getattr(exception, "response", None)
        status_code = response.status_code if response is not None else "unknown"
        logger.warning(
            "⚠️ %s error (attempt %d/%d), CVAT server may be overloaded. Waiting %.1fs before retry...",
            status_code, attempt, max_attempts, delay
        )
    else:
        logger.warning(
            "⚠️ Request failed (attempt %d/%d), retrying in %.1fs... Error: %s",
            attempt, max_attempts, delay, exception
        )
def log_max_retries_reached(
    logger: logging.Logger,
    func_name: str,
    max_retries: int,
    is_transient: bool = False,
) -> None:
    """Report, at error level, that the retry budget has been used up.

    Args:
        logger: Logger instance
        func_name: Name of the function
        max_retries: Maximum retries that were attempted
        is_transient: Whether this was for transient errors specifically
    """
    # General case first; the transient variant adds a hint about the server.
    if not is_transient:
        logger.error("❌ Max retries (%d) reached for %s", max_retries, func_name)
        return

    logger.error(
        "❌ Max transient error retries (%d) reached for %s. "
        "CVAT server appears to be having persistent issues.",
        max_retries,
        func_name,
    )
def retry_with_backoff(
    max_retries: int = 3,
    initial_delay: float = 1.0,
    max_delay: float = 60.0,
    exponential_base: float = 2.0,
    jitter: bool = True,
    # Transient error (500/502/503/504) specific parameters
    retry_transient: bool = True,
    max_transient_retries: int = 10,
    initial_transient_delay: float = 10.0,
    max_transient_delay: float = 300.0,
):
    """Build a method decorator that retries with exponential backoff.

    The decorated callable is invoked as a method (its first argument is
    ``self``); every call is routed through ``_execute_with_retry`` with the
    settings given here. Transient server errors (500/502/503/504) have
    their own retry budget and delay settings.

    Args:
        max_retries: Maximum retry attempts for general errors
        initial_delay: Initial delay between retries in seconds
        max_delay: Maximum delay between retries in seconds
        exponential_base: Base for exponential backoff calculation
        jitter: Whether to add random jitter to prevent thundering herd
        retry_transient: Whether to apply special retry logic for transient errors
        max_transient_retries: Maximum retries specifically for transient errors
        initial_transient_delay: Initial delay for transient errors
        max_transient_delay: Maximum delay for transient errors

    Returns:
        A decorator that wraps a method with the retry behaviour.
    """
    # Bundle the settings once; every wrapped call forwards the same values.
    retry_settings = {
        "max_retries": max_retries,
        "initial_delay": initial_delay,
        "max_delay": max_delay,
        "exponential_base": exponential_base,
        "jitter": jitter,
        "retry_transient": retry_transient,
        "max_transient_retries": max_transient_retries,
        "initial_transient_delay": initial_transient_delay,
        "max_transient_delay": max_transient_delay,
    }

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @wraps(func)
        def wrapper(self, *args, **kwargs) -> T:
            # Positional and keyword arguments travel as a tuple and a dict
            # so they cannot collide with the retry settings.
            return _execute_with_retry(func, self, args, kwargs, **retry_settings)

        return wrapper

    return decorator
def _execute_with_retry(
    func: Callable[..., T],
    self,
    args: tuple,
    kwargs: dict,
    max_retries: int,
    initial_delay: float,
    max_delay: float,
    exponential_base: float,
    jitter: bool,
    retry_transient: bool,
    max_transient_retries: int,
    initial_transient_delay: float,
    max_transient_delay: float,
) -> T:
    """Execute a method with retry logic.

    Calls ``func(self, *args, **kwargs)`` up to ``effective_max_retries + 1``
    times. After each failure ``_handle_retry_exception`` decides whether to
    retry and how long to wait. The logger is taken from
    ``self.client.logger`` when ``self`` has a ``client`` attribute, otherwise
    from ``self.logger``.

    Args:
        func: The undecorated method.
        self: The instance the method is called on.
        args: Positional arguments for the call.
        kwargs: Keyword arguments for the call.
        max_retries: Maximum retry attempts for general errors.
        initial_delay: Initial delay for general errors, in seconds.
        max_delay: Maximum delay for general errors, in seconds.
        exponential_base: Base for exponential backoff.
        jitter: Whether to add random jitter to the delays.
        retry_transient: Whether transient errors get their own retry budget.
        max_transient_retries: Maximum consecutive transient-error retries.
        initial_transient_delay: Initial delay for transient errors, in seconds.
        max_transient_delay: Maximum delay for transient errors, in seconds.

    Returns:
        Whatever ``func`` returns on its first successful call.

    Raises:
        Exception: The exception raised by ``func`` once it is judged
            non-retryable or the attempts are exhausted.
        RuntimeError: If no attempt was made at all (negative retry counts).
    """
    consecutive_transient_errors = 0
    # Determine max attempts based on transient error handling
    effective_max_retries = (
        max(max_retries, max_transient_retries) if retry_transient else max_retries
    )
    for attempt in range(effective_max_retries + 1):
        try:
            return func(self, *args, **kwargs)
        except Exception as e:
            logger = self.client.logger if hasattr(self, "client") else self.logger
            retry_result = _handle_retry_exception(
                e,
                logger,
                func.__name__,
                attempt,
                consecutive_transient_errors,
                max_retries,
                initial_delay,
                max_delay,
                exponential_base,
                jitter,
                retry_transient,
                max_transient_retries,
                initial_transient_delay,
                max_transient_delay,
            )
            if retry_result is None:
                # Should not retry
                raise
            if attempt >= effective_max_retries:
                # The handler can still ask for a retry on the final attempt:
                # a non-transient error resets the transient counter, so the
                # transient budget may not be used up when the loop ends.
                # Sleeping here would only delay the inevitable failure.
                log_max_retries_reached(logger, func.__name__, effective_max_retries)
                raise
            consecutive_transient_errors = retry_result["consecutive_transient_errors"]
            time.sleep(retry_result["delay"])
    # Only reachable when the loop body never ran (negative retry counts).
    raise RuntimeError("Unexpected retry error")
def _handle_retry_exception(
    exception: Exception,
    logger: logging.Logger,
    func_name: str,
    attempt: int,
    consecutive_transient_errors: int,
    max_retries: int,
    initial_delay: float,
    max_delay: float,
    exponential_base: float,
    jitter: bool,
    retry_transient: bool,
    max_transient_retries: int,
    initial_transient_delay: float,
    max_transient_delay: float,
) -> dict[str, int | float] | None:
    """Handle an exception during retry logic.

    Logs the failure, then decides whether another attempt should be made.
    Transient server errors (when ``retry_transient`` is set) are counted
    consecutively and use the transient delay settings; every other
    retryable error uses the general settings and resets that counter.

    Args:
        exception: The exception raised by the wrapped call.
        logger: Logger used for all messages.
        func_name: Name of the function being retried.
        attempt: Index of the attempt that just failed (0-based).
        consecutive_transient_errors: Transient errors seen in a row so far.
        max_retries: Maximum retry attempts for general errors.
        initial_delay: Initial delay for general errors, in seconds.
        max_delay: Maximum delay for general errors, in seconds.
        exponential_base: Base for exponential backoff.
        jitter: Whether to add random jitter to the delay.
        retry_transient: Whether transient errors get their own retry budget.
        max_transient_retries: Maximum consecutive transient-error retries.
        initial_transient_delay: Initial delay for transient errors, in seconds.
        max_transient_delay: Maximum delay for transient errors, in seconds.

    Returns:
        Dict with keys ``"consecutive_transient_errors"`` and ``"delay"`` if
        the call should be retried, None if it should not.
    """
    # Mirror the loop bound used by the caller: the transient budget only
    # extends the number of attempts when transient handling is enabled.
    total_attempts = (
        max(max_retries, max_transient_retries) if retry_transient else max_retries
    ) + 1
    logger.error(
        "❌ Request failed with error: %s. Attempt %d/%d.",
        exception, attempt + 1, total_attempts
    )
    # Check if it's a transient error (500/502/503/504)
    if is_transient_error(exception) and retry_transient:
        consecutive_transient_errors += 1
        if consecutive_transient_errors > max_transient_retries:
            log_max_retries_reached(
                logger, func_name, max_transient_retries, is_transient=True
            )
            return None
        # Calculate transient error-specific delay; the backoff exponent is
        # the number of transient errors already seen before this one.
        delay = calculate_retry_delay(
            consecutive_transient_errors - 1,
            initial_transient_delay,
            exponential_base,
            max_transient_delay,
            jitter,
        )
        log_retry_attempt(
            logger,
            func_name,
            exception,
            consecutive_transient_errors,
            max_transient_retries + 1,
            delay,
            is_transient=True,
        )
        return {
            "consecutive_transient_errors": consecutive_transient_errors,
            "delay": delay,
        }
    # Not a transient error - check if we should retry
    if not should_retry_error(exception):
        return None
    # Check if we've exceeded general retries
    if attempt >= max_retries:
        log_max_retries_reached(logger, func_name, max_retries)
        return None
    # Calculate regular delay
    delay = calculate_retry_delay(
        attempt, initial_delay, exponential_base, max_delay, jitter
    )
    log_retry_attempt(logger, func_name, exception, attempt + 1, max_retries + 1, delay)
    # A non-transient failure breaks the run of consecutive transient errors.
    return {"consecutive_transient_errors": 0, "delay": delay}