MorphGuard / src /api_utils.py
juanquy's picture
Initial clean commit of modular MorphGuard
2978bba
Raw
History Blame Contribute Delete
55.3 kB
"""
Enhanced API utilities for MorphGuard.
This module extends the existing API client with advanced features for resilient API calls:
- Smart request queuing with priority handling
- Comprehensive error handling with detailed diagnostics
- Intelligent caching strategy with stale-while-revalidate support
- Circuit breaker pattern implementation
- Extended metrics and monitoring
- Context-aware request retry strategies
"""
import os
import time
import json
import logging
import threading
import hashlib
import queue
from typing import Dict, Any, List, Optional, Union, Callable, Tuple, Type
from datetime import datetime
from functools import wraps
from urllib.parse import urljoin

import requests

# Import error handling for proper error classification
from src.error_handling import (
    MGError, ErrorCode, ErrorSeverity, ErrorCategory,
    APIError, NetworkError, handle_exception
)
# Import telemetry for metrics tracking
from src.telemetry import get_telemetry, EventCategory
# Core module for consistent initialization
from src.core import get_core
class CircuitBreaker:
    """
    Guard that stops calling an endpoint once it keeps failing.

    The breaker is always in one of three states:
    - CLOSED: calls pass through and consecutive failures are counted
    - OPEN: calls are rejected until ``recovery_timeout`` seconds have
      passed since the last recorded failure
    - HALF_OPEN: a limited number of trial calls are let through to probe
      whether the service has recovered
    """

    # Circuit states
    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half-open"

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: float = 30.0,
        half_open_max_calls: int = 1
    ):
        """
        Create a breaker in the CLOSED state.

        Args:
            failure_threshold: Failures (while closed) that trip the circuit
            recovery_timeout: Seconds to stay open before probing again
            half_open_max_calls: Trial calls admitted while half-open
        """
        # Tunables
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.half_open_max_calls = half_open_max_calls

        # Mutable state, always accessed under ``self.lock``
        self.state = self.CLOSED
        self.failure_count = 0
        self.last_failure_time = 0
        self.half_open_calls = 0

        # Re-entrant so public methods may call each other while locked
        self.lock = threading.RLock()

    def record_success(self):
        """Note a successful call; closes the circuit if it was not closed."""
        with self.lock:
            self.failure_count = 0
            if self.state == self.CLOSED:
                return
            self.state = self.CLOSED
            self.half_open_calls = 0

    def record_failure(self):
        """Note a failed call; may trip (or re-trip) the circuit."""
        with self.lock:
            self.last_failure_time = time.time()

            if self.state == self.HALF_OPEN:
                # The probe failed: go straight back to OPEN.
                self.state = self.OPEN
                self.half_open_calls = 0
            elif self.state == self.CLOSED:
                self.failure_count += 1
                tripped = self.failure_count >= self.failure_threshold
                if tripped:
                    self.state = self.OPEN

    def allow_request(self) -> bool:
        """
        Decide whether a call may proceed right now.

        Returns:
            True if the call should be attempted, False to fail fast
        """
        with self.lock:
            if self.state == self.OPEN:
                cooled_down = (
                    time.time() - self.last_failure_time >= self.recovery_timeout
                )
                if not cooled_down:
                    return False
                # Recovery window elapsed: start probing.
                self.state = self.HALF_OPEN
                self.half_open_calls = 0

            if self.state == self.HALF_OPEN:
                if self.half_open_calls >= self.half_open_max_calls:
                    return False
                self.half_open_calls += 1
                return True

            # CLOSED (or any unexpected state) lets the call through.
            return True

    def get_state(self) -> str:
        """Return the current state string."""
        with self.lock:
            return self.state

    def get_failure_count(self) -> int:
        """Return the number of failures counted so far."""
        with self.lock:
            return self.failure_count

    def reset(self):
        """Force the breaker back to a pristine CLOSED state."""
        with self.lock:
            self.half_open_calls = 0
            self.failure_count = 0
            self.last_failure_time = 0
            self.state = self.CLOSED
class StaleWhileRevalidateCache:
    """
    Two-level (memory + disk) cache with stale-while-revalidate semantics.

    An entry is *fresh* until ``fresh_until`` and *stale but usable* until
    ``stale_until``. Reading a stale entry returns it immediately and, when
    background refresh is enabled and a callback is supplied, schedules a
    single refresh on a worker thread.
    """

    # Keys every entry (in memory or loaded from disk) must carry.
    _ENTRY_KEYS = ("data", "fresh_until", "stale_until", "created_at")

    def __init__(
        self,
        cache_dir: str,
        max_entries: int = 100,
        stale_timeout: float = 300.0,  # 5 minutes
        revalidation_timeout: float = 60.0,  # 1 minute
        background_refresh: bool = True
    ):
        """
        Initialize the cache.

        Args:
            cache_dir: Directory to store cache files (created if missing)
            max_entries: Maximum number of in-memory cache entries
            stale_timeout: Seconds an entry stays fresh (default TTL)
            revalidation_timeout: Extra seconds a stale entry stays usable
            background_refresh: Whether to refresh stale entries in background
        """
        self.cache_dir = cache_dir
        self.max_entries = max_entries
        self.stale_timeout = stale_timeout
        self.revalidation_timeout = revalidation_timeout
        self.background_refresh = background_refresh

        # Create cache directory if it doesn't exist
        os.makedirs(cache_dir, exist_ok=True)

        # In-memory cache: key -> entry dict
        self.cache = {}

        # Revalidation queue and thread
        self.revalidation_queue = queue.Queue()
        self.revalidation_thread = None
        # Keys with a refresh already queued or running, so repeated reads
        # of the same stale entry do not pile up duplicate refreshes.
        self._pending_revalidations = set()

        # Lock for thread safety
        self.lock = threading.RLock()

        # Start background thread if enabled
        if background_refresh:
            self._start_revalidation_thread()

    def _start_revalidation_thread(self):
        """Start the background thread for revalidation."""
        self.revalidation_thread = threading.Thread(
            target=self._revalidation_worker,
            daemon=True
        )
        self.revalidation_thread.start()

    def _revalidation_worker(self):
        """Process queued refresh tasks until the ``None`` sentinel arrives."""
        while True:
            try:
                # Blocks until a task is available, so no polling is needed.
                task = self.revalidation_queue.get()
                try:
                    if task is None:  # Sentinel value to stop thread
                        break
                    key, refresh_callback = task
                    self._refresh_entry(key, refresh_callback)
                finally:
                    # Always balance get() so queue.join() cannot hang.
                    self.revalidation_queue.task_done()
            except Exception as e:
                # Log error but keep thread running
                logging.error("Error in revalidation worker: %s", e)
                # Back off briefly so a persistent error cannot spin the CPU.
                time.sleep(0.1)

    def _refresh_entry(self, key: str, refresh_callback: Callable) -> None:
        """Run one refresh callback and store its non-None result."""
        try:
            fresh_data = refresh_callback()
            if fresh_data is not None:
                self.put(key, fresh_data)
        except Exception as e:
            # Log error but continue processing queue
            logging.error("Error refreshing cache for %s: %s", key, e)
        finally:
            with self.lock:
                self._pending_revalidations.discard(key)

    def _schedule_revalidation(
        self,
        key: str,
        refresh_callback: Optional[Callable]
    ) -> None:
        """Queue a background refresh for ``key`` unless one is pending."""
        if not (refresh_callback and self.background_refresh):
            return
        with self.lock:
            if key in self._pending_revalidations:
                return
            self._pending_revalidations.add(key)
        self.revalidation_queue.put((key, refresh_callback))

    def get(
        self,
        key: str,
        refresh_callback: Optional[Callable] = None
    ) -> Tuple[Any, bool]:
        """
        Get an item from the cache.

        Args:
            key: Cache key
            refresh_callback: Zero-argument callable producing fresh data;
                used for background refresh when the entry is stale

        Returns:
            Tuple of (data, is_fresh); ``(None, False)`` on a miss or when
            the entry is past its stale window
        """
        with self.lock:
            now = time.time()

            entry = self.cache.get(key)
            if entry is not None and now >= entry["stale_until"]:
                # Expired in memory; the disk copy may be newer, so fall
                # through to the disk lookup below.
                del self.cache[key]
                entry = None

            if entry is None:
                entry = self._get_from_disk(key)
                if entry is None or now >= entry["stale_until"]:
                    # No cache hit or data is too stale
                    return None, False
                # Promote the usable disk entry into memory.
                self.cache[key] = entry
                if len(self.cache) > self.max_entries:
                    self._evict_oldest_entry()

            if now < entry["fresh_until"]:
                return entry["data"], True

            # Stale but usable: serve it and refresh in the background.
            self._schedule_revalidation(key, refresh_callback)
            return entry["data"], False

    def put(self, key: str, data: Any, custom_ttl: Optional[float] = None) -> None:
        """
        Put an item in the cache (memory and disk).

        Args:
            key: Cache key
            data: Data to cache; must be JSON-serializable to reach disk
            custom_ttl: Optional freshness TTL in seconds. ``0`` is honoured
                (the entry is immediately stale); only ``None`` selects the
                default ``stale_timeout``.
        """
        current_time = time.time()
        ttl = self.stale_timeout if custom_ttl is None else custom_ttl
        fresh_until = current_time + ttl
        entry = {
            "data": data,
            "fresh_until": fresh_until,
            "stale_until": fresh_until + self.revalidation_timeout,
            "created_at": current_time
        }
        with self.lock:
            # Add to memory cache
            self.cache[key] = entry
            # Handle memory cache size limit
            if len(self.cache) > self.max_entries:
                self._evict_oldest_entry()
            # Save to disk
            self._save_to_disk(key, entry)

    def _evict_oldest_entry(self) -> None:
        """Evict the entry with the smallest ``created_at`` from memory."""
        if not self.cache:
            return
        oldest_key = min(self.cache, key=lambda k: self.cache[k]["created_at"])
        del self.cache[oldest_key]

    def _get_from_disk(self, key: str) -> Optional[Dict[str, Any]]:
        """
        Get an entry from disk cache.

        Args:
            key: Cache key

        Returns:
            Cache entry, or None if missing, unreadable or malformed
        """
        try:
            # Create a safe filename from the key
            filename = self._key_to_filename(key)
            filepath = os.path.join(self.cache_dir, filename)
            if not os.path.exists(filepath):
                return None
            with open(filepath, "r") as f:
                entry = json.load(f)
        except Exception as e:
            logging.error("Error reading from disk cache: %s", e)
            return None

        # Reject files that parse as JSON but are not a complete entry, so
        # callers can index the timestamp fields safely.
        if not isinstance(entry, dict) or any(
            field not in entry for field in self._ENTRY_KEYS
        ):
            logging.error(
                "Error reading from disk cache: %s",
                f"malformed cache entry in {filename}"
            )
            return None
        return entry

    def _save_to_disk(self, key: str, entry: Dict[str, Any]) -> None:
        """
        Save an entry to disk cache atomically.

        The entry is serialized first and written to a temporary file that is
        then renamed over the target, so a serialization or I/O failure never
        leaves a truncated cache file behind.

        Args:
            key: Cache key
            entry: Cache entry
        """
        tmp_path = None
        try:
            # Create a safe filename from the key
            filename = self._key_to_filename(key)
            filepath = os.path.join(self.cache_dir, filename)
            payload = json.dumps(entry)
            tmp_path = f"{filepath}.{os.getpid()}.{threading.get_ident()}.tmp"
            with open(tmp_path, "w") as f:
                f.write(payload)
            os.replace(tmp_path, filepath)
        except Exception as e:
            logging.error("Error writing to disk cache: %s", e)
            if tmp_path is not None:
                try:
                    os.remove(tmp_path)
                except OSError:
                    # Temp file was never created or is already gone.
                    pass

    def _key_to_filename(self, key: str) -> str:
        """
        Convert a cache key to a safe filename.

        Args:
            key: Cache key

        Returns:
            Filename of the form ``cache_<md5 hex>.json``
        """
        # MD5 is used only to derive a filesystem-safe name, not for security.
        hashed = hashlib.md5(key.encode()).hexdigest()
        return f"cache_{hashed}.json"

    def invalidate(self, key: str) -> None:
        """
        Invalidate a cache entry in memory and on disk.

        Args:
            key: Cache key to invalidate
        """
        with self.lock:
            # Remove from memory cache
            self.cache.pop(key, None)
            # Remove from disk cache
            try:
                filename = self._key_to_filename(key)
                filepath = os.path.join(self.cache_dir, filename)
                if os.path.exists(filepath):
                    os.remove(filepath)
            except Exception as e:
                logging.error("Error removing cache file: %s", e)

    def clear(self) -> None:
        """Clear all cache entries from memory and disk."""
        with self.lock:
            # Clear memory cache
            self.cache.clear()
            # Clear disk cache
            try:
                for filename in os.listdir(self.cache_dir):
                    if filename.startswith("cache_") and filename.endswith(".json"):
                        os.remove(os.path.join(self.cache_dir, filename))
            except Exception as e:
                logging.error("Error clearing disk cache: %s", e)

    def shutdown(self) -> None:
        """Shutdown the cache, stopping the background thread if running."""
        if self.revalidation_thread and self.revalidation_thread.is_alive():
            self.revalidation_queue.put(None)  # Sentinel to stop thread
            self.revalidation_thread.join(timeout=1.0)
class APIUtilsConfig:
    """Settings container for the API utilities."""

    def __init__(self, options: Dict[str, Any] = None):
        """
        Populate every setting with its default, then apply overrides.

        Args:
            options: Custom configuration options; keys that do not match an
                existing attribute are ignored
        """
        defaults = {
            "base_url": "http://localhost:5000/api",
            "timeout": 30.0,
            "retries": 3,
            "retry_delay": 1.0,
            "retry_status_codes": [408, 429, 500, 502, 503, 504],
            "cache_enabled": True,
            "cache_dir": ".mg_api_cache",
            "max_cache_entries": 100,
            "stale_timeout": 300.0,  # 5 minutes
            "revalidation_timeout": 600.0,  # 10 minutes
            "background_refresh": True,
            "circuit_breaker_enabled": True,
            "failure_threshold": 5,
            "recovery_timeout": 30.0,
            "priority_queue_enabled": True,
            "max_queue_size": 100,
            "compression_enabled": True,
            "debug": False,
        }
        for name, default in defaults.items():
            setattr(self, name, default)

        # Apply user-provided overrides for known attributes only
        for name, override in (options or {}).items():
            if hasattr(self, name):
                setattr(self, name, override)

    def to_dict(self) -> Dict[str, Any]:
        """
        Export the configuration as a plain dictionary.

        Returns:
            Mapping of every public, non-callable attribute to its value
        """
        exported = {}
        for name in dir(self):
            if name.startswith("_"):
                continue
            value = getattr(self, name)
            if callable(value):
                continue
            exported[name] = value
        return exported
class APIUtils:
    """
    Enhanced API utilities for MorphGuard with resilient request handling.

    Features:
    - Smart request queuing with priority handling
    - Comprehensive error handling with detailed diagnostics
    - Intelligent caching strategy with stale-while-revalidate support
    - Circuit breaker pattern implementation
    - Extended metrics and monitoring
    - Context-aware request retry strategies
    """

    def __init__(self, options: Optional[Dict[str, Any]] = None):
        """
        Initialize API utilities.

        Args:
            options: Configuration options (see ``APIUtilsConfig``)
        """
        # Load configuration
        self.config = APIUtilsConfig(options)
        # Initialize telemetry
        self.telemetry = get_telemetry()

        # Lock and metrics are created before any worker thread starts so a
        # worker can never observe a partially constructed instance.
        self.lock = threading.RLock()
        self.metrics = self._new_metrics()

        # Circuit breakers, one per endpoint, created lazily
        self.circuit_breakers = {}

        # Initialize cache
        self.cache = StaleWhileRevalidateCache(
            cache_dir=self.config.cache_dir,
            max_entries=self.config.max_cache_entries,
            stale_timeout=self.config.stale_timeout,
            revalidation_timeout=self.config.revalidation_timeout,
            background_refresh=self.config.background_refresh
        )

        # Create requests session
        self.session = requests.Session()

        # Monotonic tie-breaker for queued requests (see _queue_request)
        self._queue_seq = 0

        # Initialize request queue for priority handling
        if self.config.priority_queue_enabled:
            self.request_queue = queue.PriorityQueue(maxsize=self.config.max_queue_size)
            self.queue_thread = threading.Thread(
                target=self._process_queue,
                daemon=True
            )
            self.queue_thread.start()

    @staticmethod
    def _new_metrics() -> Dict[str, Any]:
        """Return a zeroed metrics dictionary."""
        return {
            "total_requests": 0,
            "successful_requests": 0,
            "failed_requests": 0,
            "cached_responses": 0,
            "stale_responses": 0,
            "retries": 0,
            "circuit_breaks": 0,
            "total_time": 0,
            "priority_requests": 0
        }

    def _merge_options(self, options: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Overlay caller-supplied request options on the configured defaults."""
        merged = {
            "timeout": self.config.timeout,
            "retries": self.config.retries,
            "retry_delay": self.config.retry_delay,
            "cache": self.config.cache_enabled,
            "circuit_breaker": self.config.circuit_breaker_enabled,
            "critical": False,
            "priority": "normal",  # 'high', 'normal', 'low'
            "fallback": None,
            "headers": {}
        }
        if options:
            merged.update(options)
        return merged

    @staticmethod
    def _resolve_fallback(fallback: Any) -> Any:
        """Return the fallback value, calling it first if it is callable."""
        if callable(fallback):
            return fallback()
        return fallback

    def _get_circuit_breaker(self, endpoint: str) -> "CircuitBreaker":
        """
        Get or create a circuit breaker for the endpoint.

        Args:
            endpoint: API endpoint

        Returns:
            CircuitBreaker instance
        """
        with self.lock:
            if endpoint not in self.circuit_breakers:
                self.circuit_breakers[endpoint] = CircuitBreaker(
                    failure_threshold=self.config.failure_threshold,
                    recovery_timeout=self.config.recovery_timeout
                )
            return self.circuit_breakers[endpoint]

    def _build_url(self, endpoint: str) -> str:
        """
        Build a full URL from an endpoint.

        The endpoint is appended to ``config.base_url`` so that any path
        component of the base (for example ``/api``) is preserved.
        ``urljoin`` is deliberately not used: joining a base with a path that
        starts with ``/`` replaces the base path entirely.

        Args:
            endpoint: API endpoint, with or without a leading slash, or an
                absolute http(s) URL

        Returns:
            Full URL
        """
        # If endpoint is already a full URL, return it
        if endpoint.startswith(('http://', 'https://')):
            return endpoint
        base = self.config.base_url.rstrip('/')
        return f"{base}/{endpoint.lstrip('/')}"

    def _generate_cache_key(self, method: str, endpoint: str, data: Any = None) -> str:
        """
        Generate a cache key for a request.

        Args:
            method: HTTP method
            endpoint: API endpoint
            data: Request data

        Returns:
            Cache key (MD5 hex digest of method, endpoint and data)
        """
        # Create a string representation of the data for hashing
        data_str = ""
        if data:
            if isinstance(data, dict):
                # Sort keys for consistent hashing
                data_str = json.dumps(data, sort_keys=True)
            else:
                data_str = str(data)
        combined = f"{method}:{endpoint}:{data_str}"
        return hashlib.md5(combined.encode()).hexdigest()

    def _calculate_priority(self, options: Dict[str, Any]) -> int:
        """
        Calculate request priority based on options.

        Args:
            options: Request options

        Returns:
            Priority value (lower is higher priority)
        """
        # Base priority - lower number means higher priority
        base_priority = 100
        if options.get("critical", False):
            base_priority -= 50
        if options.get("priority") == "high":
            base_priority -= 30
        elif options.get("priority") == "low":
            base_priority += 30
        # Cache requests have lower priority
        if options.get("cache", True):
            base_priority += 10
        return base_priority

    def _process_queue(self) -> None:
        """Worker loop: execute queued requests in priority order."""
        while True:
            try:
                # Items are (priority, sequence, request); only the request
                # payload is needed here.
                _, _, request = self.request_queue.get()
                try:
                    self._run_queued_request(request)
                finally:
                    # Mark task as done even if the request handler failed
                    self.request_queue.task_done()
            except Exception as error:
                self.telemetry.error(
                    f"Error in request queue processor: {error}",
                    category=EventCategory.API,
                    exc_info=True
                )
                # Back off briefly so a persistent error cannot spin the CPU
                time.sleep(0.01)

    def _run_queued_request(self, request: Dict[str, Any]) -> None:
        """Execute one queued request and report the outcome to its callback."""
        method = request["method"]
        endpoint = request["endpoint"]
        callback = request["callback"]

        try:
            result = self._execute_request(
                method, endpoint, request["data"], request["options"]
            )
            success = True
        except Exception as error:
            result = error
            success = False

        try:
            callback(success, result)
        except Exception as error:
            self.telemetry.error(
                f"Error in request callback: {error}",
                category=EventCategory.API,
                context={
                    "method": method,
                    "endpoint": endpoint,
                    "error": str(error)
                }
            )

    def _execute_request(
        self,
        method: str,
        endpoint: str,
        data: Any = None,
        options: Optional[Dict[str, Any]] = None
    ) -> Any:
        """
        Execute a request with retries, caching, and circuit breaking.

        Args:
            method: HTTP method
            endpoint: API endpoint
            data: Request data
            options: Request options

        Returns:
            Response data, cached data, or the configured fallback

        Raises:
            Exception: If the request fails and no fallback is configured
        """
        request_options = self._merge_options(options)

        # Prepare request URL and headers
        url = self._build_url(endpoint)
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json"
        }
        headers.update(request_options["headers"])

        # Serve GET requests from cache when possible
        cache_key = None
        if method == "GET" and request_options["cache"]:
            cache_key = self._generate_cache_key(method, endpoint, data)
            cached_data, is_fresh = self.cache.get(
                cache_key,
                lambda: self._make_request(method, url, data, headers, request_options)
            )
            if cached_data is not None:
                counter = "cached_responses" if is_fresh else "stale_responses"
                with self.lock:
                    self.metrics[counter] += 1
                self.telemetry.debug(
                    f"Cache {'hit' if is_fresh else 'stale hit'} for: {method} {endpoint}",
                    category=EventCategory.API,
                    context={
                        "method": method,
                        "endpoint": endpoint,
                        "cache_status": "fresh" if is_fresh else "stale"
                    }
                )
                return cached_data

        # Check circuit breaker
        breaker = None
        if request_options["circuit_breaker"]:
            breaker = self._get_circuit_breaker(endpoint)
            if not breaker.allow_request():
                with self.lock:
                    self.metrics["circuit_breaks"] += 1
                self.telemetry.warning(
                    f"Circuit breaker open for: {method} {endpoint}",
                    category=EventCategory.API,
                    context={
                        "method": method,
                        "endpoint": endpoint,
                        "failure_count": breaker.get_failure_count()
                    }
                )
                # Use fallback if provided
                if request_options["fallback"] is not None:
                    return self._resolve_fallback(request_options["fallback"])
                raise APIError(
                    message=f"Circuit breaker open for {endpoint}",
                    code=ErrorCode.SERVICE_UNAVAILABLE,
                    retry_possible=False,
                    http_status=503
                )

        # Make the request
        try:
            start_time = time.time()
            with self.lock:
                self.metrics["total_requests"] += 1

            # Make the request with retries
            response_data = self._make_request(
                method, url, data, headers, request_options
            )

            duration = time.time() - start_time
            with self.lock:
                self.metrics["successful_requests"] += 1
                self.metrics["total_time"] += duration

            # Record success for circuit breaker
            if breaker is not None:
                breaker.record_success()

            # Cache GET responses
            if cache_key:
                self.cache.put(cache_key, response_data)

            self.telemetry.debug(
                f"Request successful: {method} {endpoint}",
                category=EventCategory.API,
                context={
                    "method": method,
                    "endpoint": endpoint,
                    "duration_ms": int(duration * 1000)
                }
            )
            return response_data
        except Exception as exc:
            with self.lock:
                self.metrics["failed_requests"] += 1

            # Record failure for circuit breaker
            if breaker is not None:
                breaker.record_failure()

            # Enhance error with context
            if isinstance(exc, MGError):
                error = exc
                # Only annotate when details is a dict, so a missing or
                # unexpected details attribute cannot mask the real failure.
                details = getattr(error, "details", None)
                if isinstance(details, dict):
                    details["endpoint"] = endpoint
                    details["method"] = method
            else:
                # Convert to MGError
                error = handle_exception(
                    exc,
                    default_error_class=APIError,
                    default_message=f"Request failed: {method} {endpoint}"
                )

            self.telemetry.error(
                f"Request failed: {method} {endpoint}",
                category=EventCategory.API,
                context={
                    "method": method,
                    "endpoint": endpoint,
                    "error": str(error),
                    "error_code": getattr(error, "code", None)
                }
            )

            # Use fallback if provided
            if request_options["fallback"] is not None:
                return self._resolve_fallback(request_options["fallback"])

            if error is exc:
                raise
            raise error from exc

    def _make_request(
        self,
        method: str,
        url: str,
        data: Any,
        headers: Dict[str, str],
        options: Dict[str, Any]
    ) -> Any:
        """
        Make a request with retries.

        Args:
            method: HTTP method
            url: Request URL
            data: Request data (query params for GET, JSON body otherwise)
            headers: Request headers
            options: Request options (``timeout`` is required; ``retries``
                and ``retry_delay`` fall back to the configuration)

        Returns:
            Parsed JSON, text, or raw bytes depending on the content type

        Raises:
            APIError: For HTTP error statuses and timeouts
            NetworkError: For connection failures
            requests.exceptions.RequestException: For other transport errors
        """
        request_kwargs = {
            "method": method,
            "url": url,
            "headers": headers,
            "timeout": options["timeout"]
        }
        # Add data based on method
        if method == "GET":
            if data:
                request_kwargs["params"] = data
        elif data is not None:
            request_kwargs["json"] = data

        total_retries = options.get("retries", self.config.retries)
        retries_left = total_retries
        retry_delay = options.get("retry_delay", self.config.retry_delay)

        while True:
            try:
                response = self.session.request(**request_kwargs)
                if response.ok:
                    return self._parse_response(response)

                # NOTE: a 401 is where a token refresh + retry would go.
                if (
                    response.status_code in self.config.retry_status_codes
                    and retries_left > 0
                ):
                    retries_left -= 1
                    self._wait_before_retry(
                        method, url, retry_delay, retries_left, total_retries,
                        {"status_code": response.status_code}
                    )
                    continue

                raise self._error_from_response(response)
            except (requests.exceptions.RequestException, APIError) as exc:
                error = self._classify_transport_error(exc, url)

                # Transport-level failures are always retryable; API errors
                # only when they are flagged as such.
                retryable = (
                    isinstance(exc, requests.exceptions.RequestException)
                    or getattr(error, "retry_possible", False)
                )
                if retries_left > 0 and retryable:
                    retries_left -= 1
                    self._wait_before_retry(
                        method, url, retry_delay, retries_left, total_retries,
                        {"error": str(error)}
                    )
                    continue

                # No more retries: surface the classified error, keeping the
                # original exception as its cause.
                if error is exc:
                    raise
                raise error from exc

    def _wait_before_retry(
        self,
        method: str,
        url: str,
        retry_delay: float,
        retries_left: int,
        total_retries: int,
        extra_context: Dict[str, Any]
    ) -> None:
        """Record a retry in metrics/telemetry and sleep for the backoff delay."""
        backoff_delay = self._calculate_backoff_delay(
            retry_delay, retries_left, total_retries
        )
        with self.lock:
            self.metrics["retries"] += 1
        context = {"method": method, "url": url}
        context.update(extra_context)
        context["retries_left"] = retries_left
        self.telemetry.debug(
            f"Retrying request ({retries_left} attempts left): {method} {url}",
            category=EventCategory.API,
            context=context
        )
        time.sleep(backoff_delay)

    @staticmethod
    def _classify_transport_error(exc: Exception, url: str) -> Exception:
        """Map connection/timeout exceptions to project errors; pass others through."""
        if isinstance(exc, requests.exceptions.ConnectionError):
            return NetworkError(
                message="Connection error",
                details={"url": url},
                original_error=exc
            )
        if isinstance(exc, requests.exceptions.Timeout):
            return APIError(
                message="Request timed out",
                code=ErrorCode.TIMEOUT,
                details={"url": url},
                retry_possible=True,
                original_error=exc,
                http_status=408
            )
        return exc

    @staticmethod
    def _error_from_response(response: Any) -> Exception:
        """Build an APIError describing a non-2xx HTTP response."""
        try:
            error_data = response.json()
        except ValueError:
            # Body is not valid JSON
            error_data = {"message": response.reason}
        if not isinstance(error_data, dict):
            # JSON arrays/scalars have no "message" key to read
            error_data = {"message": response.reason, "response": error_data}

        status = response.status_code
        error_kwargs = {"details": error_data, "http_status": status}
        if status == 401:
            code, default_message = ErrorCode.UNAUTHORIZED, "Unauthorized"
        elif status == 403:
            code, default_message = ErrorCode.FORBIDDEN, "Forbidden"
        elif status == 404:
            code, default_message = ErrorCode.NOT_FOUND, "Not found"
        elif status == 429:
            code, default_message = ErrorCode.RATE_LIMITED, "Rate limited"
            error_kwargs["retry_possible"] = True
        elif status >= 500:
            code, default_message = ErrorCode.SERVICE_UNAVAILABLE, "Server error"
            error_kwargs["retry_possible"] = True
        else:
            code, default_message = ErrorCode.API_ERROR, "API error"

        return APIError(
            message=error_data.get("message", default_message),
            code=code,
            **error_kwargs
        )

    @staticmethod
    def _parse_response(response: Any) -> Any:
        """Decode a successful response according to its content type."""
        content_type = response.headers.get("content-type", "")
        if "application/json" in content_type:
            return response.json()
        if any(text_type in content_type for text_type in ["text/", "application/xml"]):
            return response.text
        # For binary data, return the raw content
        return response.content

    def _calculate_backoff_delay(
        self,
        base_delay: float,
        attempts_left: int,
        total_retries: Optional[int] = None
    ) -> float:
        """
        Calculate backoff delay for retries with jitter.

        Args:
            base_delay: Base delay in seconds
            attempts_left: Number of attempts left
            total_retries: Retry budget the attempt count is measured against;
                defaults to ``config.retries``

        Returns:
            Delay in seconds, capped at 30
        """
        import random

        if total_retries is None:
            total_retries = self.config.retries
        # Number of retries already used. Clamped at zero so a caller whose
        # budget exceeds ``total_retries`` never gets a sub-base delay.
        attempt = max(total_retries - attempts_left, 0)
        # Exponential backoff: base_delay * 2^attempt
        exponential_delay = base_delay * (2 ** attempt)
        # Add jitter to prevent thundering herd problem (±25% randomness)
        jitter = exponential_delay * 0.5 * (random.random() - 0.5)
        # Cap at 30 seconds
        return min(exponential_delay + jitter, 30.0)

    def request(
        self,
        method: str,
        endpoint: str,
        data: Any = None,
        options: Optional[Dict[str, Any]] = None
    ) -> Any:
        """
        Make a request with full resilience features.

        Requests whose ``priority`` option is not ``"normal"`` go through the
        priority queue when it is enabled; all others run on the calling
        thread.

        Args:
            method: HTTP method
            endpoint: API endpoint
            data: Request data
            options: Request options

        Returns:
            Response data

        Raises:
            Exception: If the request fails
        """
        request_options = self._merge_options(options)

        if self.config.priority_queue_enabled and request_options.get("priority") != "normal":
            with self.lock:
                self.metrics["priority_requests"] += 1
            return self._queue_request(method, endpoint, data, request_options)

        return self._execute_request(method, endpoint, data, request_options)

    def _queue_request(
        self,
        method: str,
        endpoint: str,
        data: Any,
        options: Dict[str, Any]
    ) -> Any:
        """
        Queue a request for priority handling and wait for its result.

        Args:
            method: HTTP method
            endpoint: API endpoint
            data: Request data
            options: Request options

        Returns:
            Response data

        Raises:
            APIError: If no result arrives within the request timeout
            Exception: If the request itself fails
        """
        priority = self._calculate_priority(options)

        # Single-slot queue used to hand the outcome back to this thread
        result_queue = queue.Queue(1)

        def callback(success, result):
            result_queue.put((success, result))

        # The sequence number breaks ties between equal priorities. Without
        # it the queue would compare the request dicts and raise TypeError;
        # it also keeps equal-priority requests in FIFO order.
        with self.lock:
            self._queue_seq += 1
            sequence = self._queue_seq

        self.request_queue.put(
            (priority, sequence, {
                "method": method,
                "endpoint": endpoint,
                "data": data,
                "options": options,
                "callback": callback
            })
        )

        # Wait for the result
        try:
            success, result = result_queue.get(
                timeout=options.get("timeout", self.config.timeout)
            )
        except queue.Empty:
            raise APIError(
                message=f"Request timed out: {method} {endpoint}",
                code=ErrorCode.TIMEOUT,
                retry_possible=True,
                http_status=408
            )
        if not success:
            raise result
        return result

    def get(self, endpoint: str, data: Any = None, options: Optional[Dict[str, Any]] = None) -> Any:
        """
        Make a GET request.

        Args:
            endpoint: API endpoint
            data: Query parameters
            options: Request options

        Returns:
            Response data
        """
        return self.request("GET", endpoint, data, options)

    def post(self, endpoint: str, data: Any = None, options: Optional[Dict[str, Any]] = None) -> Any:
        """
        Make a POST request.

        Args:
            endpoint: API endpoint
            data: Request data
            options: Request options

        Returns:
            Response data
        """
        return self.request("POST", endpoint, data, options)

    def put(self, endpoint: str, data: Any = None, options: Optional[Dict[str, Any]] = None) -> Any:
        """
        Make a PUT request.

        Args:
            endpoint: API endpoint
            data: Request data
            options: Request options

        Returns:
            Response data
        """
        return self.request("PUT", endpoint, data, options)

    def patch(self, endpoint: str, data: Any = None, options: Optional[Dict[str, Any]] = None) -> Any:
        """
        Make a PATCH request.

        Args:
            endpoint: API endpoint
            data: Request data
            options: Request options

        Returns:
            Response data
        """
        return self.request("PATCH", endpoint, data, options)

    def delete(self, endpoint: str, data: Any = None, options: Optional[Dict[str, Any]] = None) -> Any:
        """
        Make a DELETE request.

        Args:
            endpoint: API endpoint
            data: Request data
            options: Request options

        Returns:
            Response data
        """
        return self.request("DELETE", endpoint, data, options)

    def get_metrics(self) -> Dict[str, Any]:
        """
        Get request metrics.

        Returns:
            Copy of the metrics dictionary
        """
        with self.lock:
            return self.metrics.copy()

    def reset_metrics(self) -> None:
        """Reset request metrics."""
        with self.lock:
            self.metrics = self._new_metrics()

    def clear_cache(self, endpoint: Optional[str] = None) -> None:
        """
        Clear the cache.

        When an endpoint is given, only entries keyed without request data
        are invalidated, because cache keys also hash the request data.

        Args:
            endpoint: Optional endpoint to clear (clears all if None)
        """
        if endpoint:
            for method in ["GET", "HEAD", "OPTIONS"]:
                cache_key = self._generate_cache_key(method, endpoint)
                self.cache.invalidate(cache_key)
        else:
            # Clear all cache
            self.cache.clear()

    def reset_circuit_breaker(self, endpoint: str) -> None:
        """
        Reset the circuit breaker for an endpoint.

        Args:
            endpoint: API endpoint
        """
        with self.lock:
            if endpoint in self.circuit_breakers:
                self.circuit_breakers[endpoint].reset()

    def reset_all_circuit_breakers(self) -> None:
        """Reset all circuit breakers."""
        with self.lock:
            for breaker in self.circuit_breakers.values():
                breaker.reset()

    def get_circuit_breaker_status(self, endpoint: str) -> Dict[str, Any]:
        """
        Get the status of a circuit breaker.

        Args:
            endpoint: API endpoint

        Returns:
            Dict with ``state`` and ``failure_count``; an endpoint with no
            breaker yet is reported as closed with zero failures
        """
        with self.lock:
            if endpoint not in self.circuit_breakers:
                return {
                    "state": "closed",
                    "failure_count": 0
                }
            breaker = self.circuit_breakers[endpoint]
            return {
                "state": breaker.get_state(),
                "failure_count": breaker.get_failure_count()
            }

    def shutdown(self) -> None:
        """Shutdown the API utilities, closing resources."""
        # Close request session
        self.session.close()
        # Shutdown cache
        self.cache.shutdown()
        # Reset metrics
        self.reset_metrics()
# Process-wide APIUtils instance, created on first use
_instance = None


def get_api_utils(options: Dict[str, Any] = None) -> APIUtils:
    """
    Return the shared APIUtils instance, building it on the first call.

    Args:
        options: Optional configuration options; only used when the
            instance is created, ignored on later calls

    Returns:
        APIUtils instance
    """
    global _instance
    if _instance is not None:
        return _instance
    _instance = APIUtils(options)
    return _instance
# Decorator for retrying functions
def retry(
    retries: int = 3,
    retry_delay: float = 1.0,
    retryable_exceptions: Optional[List[Type[Exception]]] = None,
    fallback: Any = None
):
    """
    Decorator for retrying functions.

    The shared APIUtils instance (metrics, telemetry, backoff) is looked up
    only when a retry is actually needed, so a call that succeeds first time
    has no side effects beyond the wrapped function itself.

    Args:
        retries: Number of retries after the initial attempt
        retry_delay: Base delay between retries in seconds
        retryable_exceptions: Exception types to retry on (default: Exception)
        fallback: Fallback value, or callable invoked with the original
            arguments, used if all retries fail

    Returns:
        Decorator function
    """
    if retryable_exceptions is None:
        retryable_exceptions = [Exception]
    # Built once; ``except`` requires a tuple of exception classes.
    retryable = tuple(retryable_exceptions)

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Function context for logging
            func_name = func.__name__
            module_name = func.__module__

            attempts_left = retries
            while True:
                try:
                    return func(*args, **kwargs)
                except retryable as error:
                    if attempts_left <= 0:
                        # No more retries, use fallback or re-raise
                        if fallback is not None:
                            if callable(fallback):
                                return fallback(*args, **kwargs)
                            return fallback
                        raise

                    # Get API utils for backoff, metrics and telemetry
                    api_utils = get_api_utils()

                    # NOTE(review): the helper measures the attempt number
                    # against config.retries, so the delay sequence is only a
                    # clean exponential when ``retries`` matches that setting.
                    backoff_delay = api_utils._calculate_backoff_delay(
                        retry_delay, attempts_left
                    )

                    with api_utils.lock:
                        api_utils.metrics["retries"] += 1

                    api_utils.telemetry.debug(
                        f"Retrying function {func_name} ({attempts_left} attempts left)",
                        category=EventCategory.API,
                        context={
                            "function": func_name,
                            "module": module_name,
                            "error": str(error),
                            "retries_left": attempts_left
                        }
                    )

                    # Wait before retrying
                    time.sleep(backoff_delay)
                    attempts_left -= 1
        return wrapper
    return decorator
# Decorator for circuit breaking
def circuit_breaker(
    failure_threshold: int = 5,
    recovery_timeout: float = 30.0,
    fallback: Any = None
):
    """
    Decorator for applying circuit breaker pattern to functions.

    Each decorated function gets its own CircuitBreaker, exposed as
    ``wrapper.circuit_breaker``. The shared APIUtils instance is looked up
    only when a call is rejected, so calls that pass through do not touch it.

    Args:
        failure_threshold: Number of failures before opening the circuit
        recovery_timeout: Time in seconds before attempting recovery
        fallback: Fallback value, or callable invoked with the original
            arguments, used if the circuit is open

    Returns:
        Decorator function
    """
    def decorator(func):
        # Create a circuit breaker for this function
        cb = CircuitBreaker(
            failure_threshold=failure_threshold,
            recovery_timeout=recovery_timeout
        )

        @wraps(func)
        def wrapper(*args, **kwargs):
            if not cb.allow_request():
                # Get API utils for metrics and telemetry
                api_utils = get_api_utils()
                func_name = func.__name__

                with api_utils.lock:
                    api_utils.metrics["circuit_breaks"] += 1

                api_utils.telemetry.warning(
                    f"Circuit breaker open for function {func_name}",
                    category=EventCategory.API,
                    context={
                        "function": func_name,
                        "module": func.__module__,
                        "failure_count": cb.get_failure_count()
                    }
                )

                # Use fallback if provided
                if fallback is not None:
                    if callable(fallback):
                        return fallback(*args, **kwargs)
                    return fallback

                raise APIError(
                    message=f"Circuit breaker open for {func_name}",
                    code=ErrorCode.SERVICE_UNAVAILABLE,
                    retry_possible=False,
                    http_status=503
                )

            try:
                result = func(*args, **kwargs)
            except Exception:
                # Count the failure, then let the caller see the error
                cb.record_failure()
                raise
            cb.record_success()
            return result

        # Add circuit breaker to function for external access
        wrapper.circuit_breaker = cb
        return wrapper
    return decorator
# Decorator for caching function results
def cacheable(
    ttl: float = 300.0,  # 5 minutes
    key_func: Optional[Callable] = None
):
    """
    Decorator for caching function results in the shared APIUtils cache.

    A result of ``None`` is never served from cache, because the cache
    lookup cannot distinguish it from a miss.

    Args:
        ttl: Cache TTL in seconds, applied to the initial store and to
            background refreshes alike
        key_func: Optional function, called with the wrapped function's
            arguments, that generates the cache key

    Returns:
        Decorator function
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Get API utils for cache
            api_utils = get_api_utils()

            # Generate cache key
            if key_func:
                # str() so non-string keys (ints, tuples) can be hashed too
                cache_key = str(key_func(*args, **kwargs))
            else:
                # Default key generation
                arg_key = ",".join(str(arg) for arg in args)
                kwarg_key = ",".join(f"{k}={v}" for k, v in sorted(kwargs.items()))
                cache_key = f"{func.__module__}.{func.__name__}({arg_key},{kwarg_key})"

            # Hash the key
            hashed_key = hashlib.md5(cache_key.encode()).hexdigest()

            def refresh():
                # Store the refreshed value here, with this decorator's TTL.
                # Returning None stops the cache's refresh worker from also
                # storing it under the cache-wide default TTL.
                fresh = func(*args, **kwargs)
                if fresh is not None:
                    api_utils.cache.put(hashed_key, fresh, ttl)
                return None

            # Check cache
            cached_data, is_fresh = api_utils.cache.get(hashed_key, refresh)
            if cached_data is not None:
                counter = "cached_responses" if is_fresh else "stale_responses"
                with api_utils.lock:
                    api_utils.metrics[counter] += 1
                return cached_data

            # Cache miss, call function
            result = func(*args, **kwargs)
            # Cache result
            api_utils.cache.put(hashed_key, result, ttl)
            return result
        return wrapper
    return decorator