# MorphGuard / src/api_fallback_registry.py
# Revision 2978bba: "Initial clean commit of modular MorphGuard" (author: juanquy)
"""
Enhanced API fallback registry for MorphGuard.
This module provides a centralized registry for API fallbacks with:
- Context-aware fallback selection
- Tiered fallback strategies
- Dynamic fallback generation
- Health-aware routing
- Degraded mode operations
"""
import json
import logging
import os
import random
import re
import threading
import time
from enum import Enum
from typing import Dict, Any, List, Optional, Union, Callable, Tuple, TypeVar

# Import error handling for proper error classification
from src.error_handling import MGError, ErrorCode, ErrorSeverity, ErrorCategory
from src.telemetry import get_telemetry, EventCategory
# Type definitions
# Generic placeholder for the value a fallback callable returns.
T = TypeVar('T')
# A callable fallback: accepts arbitrary arguments and returns a value of type T.
FallbackFunc = Callable[..., T]
# Alias for a static fallback value (plain ``Any``).
# NOTE(review): not referenced elsewhere in the visible module - confirm external use.
FallbackValue = Any
class FallbackTier(Enum):
    """
    Tiers of fallbacks, ordered by preference.

    A lower value means a more preferred tier. Members of this enum compare
    with each other by value, so a collection of tiers can be passed straight
    to ``sorted()`` or ``min()``. A plain ``Enum`` defines no ordering and
    raises ``TypeError`` on ``<``. Comparison with any other type is
    unsupported and returns ``NotImplemented``.
    """

    PRIMARY = 0  # First choice fallback
    SECONDARY = 1  # Second choice fallback
    TERTIARY = 2  # Third choice fallback
    EMERGENCY = 3  # Last resort fallback

    def __lt__(self, other):
        if self.__class__ is other.__class__:
            return self.value < other.value
        return NotImplemented

    def __le__(self, other):
        if self.__class__ is other.__class__:
            return self.value <= other.value
        return NotImplemented

    def __gt__(self, other):
        if self.__class__ is other.__class__:
            return self.value > other.value
        return NotImplemented

    def __ge__(self, other):
        if self.__class__ is other.__class__:
            return self.value >= other.value
        return NotImplemented
class FallbackStrategy(Enum):
    """Strategies for selecting fallbacks."""

    # Always use the same fallback
    STATIC = "static"
    # Randomly select from available fallbacks
    RANDOM = "random"
    # Select based on weights
    WEIGHTED = "weighted"
    # Select least recently used fallback
    LEAST_RECENTLY_USED = "lru"
    # Cycle through fallbacks
    ROUND_ROBIN = "round_robin"
    # Select based on request context
    CONTEXT_BASED = "context"
    # Select based on health checks
    HEALTH_BASED = "health"
class HealthStatus(Enum):
    """Health status of services or endpoints."""

    HEALTHY = "healthy"
    DEGRADED = "degraded"
    UNHEALTHY = "unhealthy"
    # Initial state before any health check has recorded a result
    UNKNOWN = "unknown"
class FallbackMetadata:
    """Metadata for a fallback entry: tier, weight, usage counters and health."""

    def __init__(
        self,
        tier: FallbackTier = FallbackTier.PRIMARY,
        weight: float = 1.0,
        context_rules: Optional[Dict[str, Any]] = None,
        last_used: float = 0,
        use_count: int = 0,
        success_count: int = 0,
        failure_count: int = 0,
        health_status: HealthStatus = HealthStatus.UNKNOWN,
        last_health_check: float = 0,
        tags: Optional[Dict[str, str]] = None
    ):
        """
        Initialize fallback metadata.

        Args:
            tier: Fallback tier
            weight: Weight for weighted selection (higher = more likely to be selected)
            context_rules: Rules for context-based selection
            last_used: Timestamp of last use
            use_count: Number of times used
            success_count: Number of successful uses
            failure_count: Number of failed uses
            health_status: Current health status
            last_health_check: Timestamp of last health check
            tags: Tags for categorizing fallbacks
        """
        self.tier = tier
        self.weight = weight
        self.context_rules = context_rules or {}
        self.last_used = last_used
        self.use_count = use_count
        self.success_count = success_count
        self.failure_count = failure_count
        self.health_status = health_status
        self.last_health_check = last_health_check
        self.tags = tags or {}

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(tier={self.tier!r}, weight={self.weight!r}, "
            f"use_count={self.use_count!r}, health_status={self.health_status!r})"
        )

    def update_usage(self, success: bool = True) -> None:
        """
        Update usage statistics.

        Stamps ``last_used`` with the current time and increments the use
        counter plus either the success or the failure counter.

        Args:
            success: Whether the fallback was successful
        """
        self.last_used = time.time()
        self.use_count += 1
        if success:
            self.success_count += 1
        else:
            self.failure_count += 1

    def update_health(self, status: HealthStatus) -> None:
        """
        Update health status and record when the check happened.

        Args:
            status: New health status
        """
        self.health_status = status
        self.last_health_check = time.time()

    def success_rate(self) -> float:
        """
        Calculate success rate.

        Returns:
            Success rate (0-1), or 0.0 if never used
        """
        if self.use_count == 0:
            # Return a float so the result type matches the annotation.
            return 0.0
        return self.success_count / self.use_count

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to dictionary.

        Returns:
            Dictionary representation. ``context_rules`` and ``tags`` are
            shallow copies, so mutating the result does not alter this object.
        """
        return {
            "tier": self.tier.name,
            "weight": self.weight,
            "context_rules": dict(self.context_rules),
            "last_used": self.last_used,
            "use_count": self.use_count,
            "success_count": self.success_count,
            "failure_count": self.failure_count,
            "health_status": self.health_status.value,
            "last_health_check": self.last_health_check,
            "tags": dict(self.tags),
            "success_rate": self.success_rate()
        }
class FallbackEntry:
    """One registered fallback: a key, its value or callable, and its metadata."""

    def __init__(
        self,
        key: str,
        value: Union[Any, FallbackFunc],
        is_function: bool = False,
        metadata: Optional[FallbackMetadata] = None
    ):
        """
        Initialize a fallback entry.

        Args:
            key: Unique identifier for the fallback
            value: Fallback value or function
            is_function: Whether the value is a function
            metadata: Fallback metadata (a default instance is created when omitted)
        """
        self.key = key
        self.value = value
        self.is_function = is_function
        self.metadata = metadata or FallbackMetadata()

    def get_value(self, *args, **kwargs) -> Any:
        """
        Get the fallback value.

        A function fallback is invoked with the supplied arguments; a static
        fallback is returned as-is. Usage statistics on the metadata are
        updated on both success and failure.

        Args:
            *args: Arguments to pass to the function
            **kwargs: Keyword arguments to pass to the function

        Returns:
            Fallback value

        Raises:
            Exception: Whatever the fallback raised, after recording the failure
        """
        try:
            produced = self.value(*args, **kwargs) if self.is_function else self.value
            self.metadata.update_usage(success=True)
            return produced
        except Exception as exc:
            self.metadata.update_usage(success=False)
            raise exc

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to dictionary.

        Returns:
            Dictionary representation (the value itself is not included,
            only the name of its type)
        """
        if self.is_function:
            value_type = "function"
        else:
            value_type = type(self.value).__name__
        return {
            "key": self.key,
            "is_function": self.is_function,
            "metadata": self.metadata.to_dict(),
            "value_type": value_type
        }
class APIFallbackRegistry:
    """
    Registry for API fallbacks with intelligent selection.

    This registry provides:
    - Endpoint-specific fallbacks
    - Tiered fallback selection
    - Multiple selection strategies
    - Health-aware routing
    - Degraded mode operations
    """

    def __init__(self, options: Optional[Dict[str, Any]] = None):
        """
        Initialize the fallback registry.

        Args:
            options: Configuration options (merged over the defaults below)
        """
        # Default configuration
        self.config = {
            "default_strategy": FallbackStrategy.STATIC,
            "health_check_interval": 60,  # seconds
            "auto_health_checks": True,
            "fallback_storage_path": None,
            "load_from_storage": False,
            "save_to_storage": False,
            "context_rules_path": None,
            "degraded_mode_threshold": 0.5,  # 50% failure rate
            "enable_telemetry": True
        }
        # Update with user-provided options
        if options:
            self.config.update(options)
        # Initialize telemetry
        self.telemetry = get_telemetry()
        # Fallback entries by endpoint
        self.fallbacks: Dict[str, List[FallbackEntry]] = {}
        # Current selection indexes for round-robin
        self.round_robin_indexes: Dict[str, int] = {}
        # System health status
        self.system_health = HealthStatus.HEALTHY
        # Health check thread
        self.health_check_thread = None
        # Degraded mode flag
        self.degraded_mode = False
        # Locks
        self.fallbacks_lock = threading.RLock()
        self.health_lock = threading.RLock()
        # Set by shutdown() so the health check worker can exit promptly
        self._stop_event = threading.Event()
        # Load fallbacks from storage if enabled
        if self.config["load_from_storage"] and self.config["fallback_storage_path"]:
            self._load_from_storage()
        # Start health check thread if enabled
        if self.config["auto_health_checks"]:
            self._start_health_check_thread()

    def _start_health_check_thread(self) -> None:
        """Start the background health check thread."""
        # Allow a restart after a previous shutdown().
        self._stop_event.clear()
        self.health_check_thread = threading.Thread(
            target=self._health_check_worker,
            daemon=True,
            name="fallback_health_checker"
        )
        self.health_check_thread.start()

    def _health_check_worker(self) -> None:
        """
        Background worker to check health of fallbacks.

        Waits one check interval, runs the health checks, refreshes the
        degraded mode flag and optionally persists the registry. The loop
        ends once shutdown() sets the stop event.
        """
        while not self._stop_event.is_set():
            try:
                # Wait for the check interval; returns True as soon as a stop is requested
                if self._stop_event.wait(self.config["health_check_interval"]):
                    break
                # Run health checks
                self._check_all_health()
                # Update degraded mode flag
                self._update_degraded_mode()
                # Save to storage if enabled
                if self.config["save_to_storage"] and self.config["fallback_storage_path"]:
                    self._save_to_storage()
            except Exception as e:
                self.telemetry.error(
                    f"Error in fallback health check worker: {e}",
                    category=EventCategory.API,
                    exc_info=True
                )

    def _check_all_health(self) -> None:
        """
        Check health of all fallbacks.

        Only function fallbacks are checked: one is HEALTHY when its value is
        callable and UNHEALTHY otherwise. Static fallbacks keep whatever
        status they already have.
        """
        with self.fallbacks_lock:
            for entries in self.fallbacks.values():
                for entry in entries:
                    if not entry.is_function:
                        continue
                    if callable(entry.value):
                        entry.metadata.update_health(HealthStatus.HEALTHY)
                    else:
                        entry.metadata.update_health(HealthStatus.UNHEALTHY)

    def _update_degraded_mode(self) -> None:
        """
        Update the degraded mode flag based on health checks.

        An endpoint counts as unhealthy when it has at least one fallback and
        all of its fallbacks are UNHEALTHY. Endpoints with no entries are
        ignored so they cannot push the system into degraded mode. The flag is
        left untouched when there are no populated endpoints.
        """
        with self.health_lock:
            with self.fallbacks_lock:
                populated = [entries for entries in self.fallbacks.values() if entries]
                unhealthy_endpoints = sum(
                    1 for entries in populated
                    if all(
                        entry.metadata.health_status == HealthStatus.UNHEALTHY
                        for entry in entries
                    )
                )
            total_endpoints = len(populated)
            if total_endpoints == 0:
                return
            unhealthy_ratio = unhealthy_endpoints / total_endpoints
            old_degraded_mode = self.degraded_mode
            self.degraded_mode = unhealthy_ratio >= self.config["degraded_mode_threshold"]
            # Log degraded mode changes only
            if self.degraded_mode == old_degraded_mode:
                return
            log_context = {
                "unhealthy_ratio": unhealthy_ratio,
                "threshold": self.config["degraded_mode_threshold"]
            }
            if self.degraded_mode:
                self.telemetry.warning(
                    "System entered degraded mode",
                    category=EventCategory.API,
                    context=log_context
                )
            else:
                self.telemetry.info(
                    "System exited degraded mode",
                    category=EventCategory.API,
                    context=log_context
                )

    @staticmethod
    def _entry_from_dict(entry_data: Dict[str, Any]) -> FallbackEntry:
        """Rebuild a static FallbackEntry from its stored dictionary form."""
        stored = entry_data["metadata"]
        metadata = FallbackMetadata(
            tier=FallbackTier[stored["tier"]],
            weight=stored["weight"],
            context_rules=stored["context_rules"],
            last_used=stored["last_used"],
            use_count=stored["use_count"],
            success_count=stored["success_count"],
            failure_count=stored["failure_count"],
            health_status=HealthStatus(stored["health_status"]),
            last_health_check=stored["last_health_check"],
            tags=stored["tags"]
        )
        return FallbackEntry(
            key=entry_data["key"],
            value=entry_data["value"],
            is_function=False,
            metadata=metadata
        )

    def _load_from_storage(self) -> None:
        """
        Load fallbacks from storage.

        Only static values are restored; function fallbacks cannot be
        serialized. A malformed entry is skipped with a warning instead of
        aborting the whole load, and the registry is only updated once the
        file has been parsed completely. Errors are logged, never raised.
        """
        path = self.config["fallback_storage_path"]
        try:
            if not os.path.exists(path):
                return
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
            loaded: Dict[str, List[FallbackEntry]] = {}
            for endpoint, entries in data.items():
                restored = []
                for entry_data in entries:
                    # Function fallbacks can't be serialized, so only load static values
                    if entry_data.get("is_function", False):
                        continue
                    try:
                        restored.append(self._entry_from_dict(entry_data))
                    except (KeyError, ValueError, TypeError) as e:
                        self.telemetry.warning(
                            f"Skipping malformed stored fallback for endpoint {endpoint}: {e!r}",
                            category=EventCategory.API
                        )
                # Do not create endpoints that ended up with no usable entries
                if restored:
                    loaded[endpoint] = restored
            with self.fallbacks_lock:
                self.fallbacks.update(loaded)
            self.telemetry.info(
                "Loaded fallbacks from storage",
                category=EventCategory.API,
                context={"path": path}
            )
        except Exception as e:
            self.telemetry.error(
                f"Failed to load fallbacks from storage: {e}",
                category=EventCategory.API,
                exc_info=True
            )

    def _save_to_storage(self) -> None:
        """
        Save fallbacks to storage.

        Function fallbacks are skipped as they can't be serialized. The data
        is serialized before any file is touched and then written through a
        temporary file, so a failure cannot truncate an existing storage file.
        Errors are logged, never raised.
        """
        path = self.config["fallback_storage_path"]
        if not path:
            return
        try:
            # A bare file name has no directory part; os.makedirs("") would raise
            directory = os.path.dirname(path)
            if directory:
                os.makedirs(directory, exist_ok=True)
            with self.fallbacks_lock:
                data = {
                    endpoint: [
                        dict(entry.to_dict(), value=entry.value)  # Add actual value
                        for entry in entries
                        if not entry.is_function
                    ]
                    for endpoint, entries in self.fallbacks.items()
                }
            payload = json.dumps(data, indent=2)
            tmp_path = f"{path}.tmp"
            with open(tmp_path, "w", encoding="utf-8") as f:
                f.write(payload)
            os.replace(tmp_path, path)
            self.telemetry.debug(
                "Saved fallbacks to storage",
                category=EventCategory.API,
                context={"path": path}
            )
        except Exception as e:
            self.telemetry.error(
                f"Failed to save fallbacks to storage: {e}",
                category=EventCategory.API,
                exc_info=True
            )

    def register_fallback(
        self,
        endpoint: str,
        key: str,
        fallback: Union[Any, FallbackFunc],
        is_function: bool = False,
        tier: FallbackTier = FallbackTier.PRIMARY,
        weight: float = 1.0,
        context_rules: Optional[Dict[str, Any]] = None,
        tags: Optional[Dict[str, str]] = None
    ) -> None:
        """
        Register a fallback for an endpoint.

        Registering a key that already exists for the endpoint replaces the
        entry but keeps its usage and health statistics (and its tags when no
        new tags are given).

        Args:
            endpoint: API endpoint pattern (can use wildcards like * for pattern matching)
            key: Unique identifier for the fallback
            fallback: Fallback value or function
            is_function: Whether the fallback is a function
            tier: Fallback tier
            weight: Weight for weighted selection
            context_rules: Rules for context-based selection
            tags: Tags for categorizing fallbacks
        """
        with self.fallbacks_lock:
            entries = self.fallbacks.setdefault(endpoint, [])
            for index, existing in enumerate(entries):
                if existing.key != key:
                    continue
                previous = existing.metadata
                entries[index] = FallbackEntry(
                    key=key,
                    value=fallback,
                    is_function=is_function,
                    metadata=FallbackMetadata(
                        tier=tier,
                        weight=weight,
                        context_rules=context_rules,
                        last_used=previous.last_used,
                        use_count=previous.use_count,
                        success_count=previous.success_count,
                        failure_count=previous.failure_count,
                        health_status=previous.health_status,
                        last_health_check=previous.last_health_check,
                        tags=tags or previous.tags
                    )
                )
                self.telemetry.debug(
                    f"Updated fallback {key} for endpoint {endpoint}",
                    category=EventCategory.API
                )
                return
            # Add new entry
            entries.append(FallbackEntry(
                key=key,
                value=fallback,
                is_function=is_function,
                metadata=FallbackMetadata(
                    tier=tier,
                    weight=weight,
                    context_rules=context_rules,
                    tags=tags
                )
            ))
            self.telemetry.debug(
                f"Registered fallback {key} for endpoint {endpoint}",
                category=EventCategory.API
            )

    def unregister_fallback(self, endpoint: str, key: str) -> bool:
        """
        Unregister a fallback.

        Args:
            endpoint: API endpoint
            key: Fallback key

        Returns:
            Whether the fallback was found and unregistered
        """
        with self.fallbacks_lock:
            entries = self.fallbacks.get(endpoint)
            if entries is None:
                return False
            for index, entry in enumerate(entries):
                if entry.key != key:
                    continue
                entries.pop(index)
                # Remove endpoint if empty
                if not entries:
                    del self.fallbacks[endpoint]
                self.telemetry.debug(
                    f"Unregistered fallback {key} for endpoint {endpoint}",
                    category=EventCategory.API
                )
                return True
            return False

    @staticmethod
    def _pattern_matches(pattern: str, endpoint: str) -> bool:
        """Return whether endpoint matches pattern, where * matches any run of characters."""
        if "*" not in pattern:
            return pattern == endpoint
        # Escape the literal pieces so only * acts as a wildcard
        regex = ".*".join(re.escape(piece) for piece in pattern.split("*"))
        return re.fullmatch(regex, endpoint, flags=re.DOTALL) is not None

    def get_fallback(
        self,
        endpoint: str,
        strategy: Optional[FallbackStrategy] = None,
        context: Optional[Dict[str, Any]] = None,
        args: Optional[List[Any]] = None,
        kwargs: Optional[Dict[str, Any]] = None
    ) -> Optional[Any]:
        """
        Get a fallback value for an endpoint.

        Entries registered under every pattern that matches the endpoint are
        candidates. When the selected fallback raises, the error is logged
        and the remaining candidates are tried until one succeeds.

        Args:
            endpoint: API endpoint
            strategy: Selection strategy (the configured default when None)
            context: Request context for context-based selection
            args: Arguments to pass to function fallbacks
            kwargs: Keyword arguments to pass to function fallbacks

        Returns:
            Fallback value or None if no fallback is available
        """
        with self.fallbacks_lock:
            # Collect entries of every matching endpoint (support wildcard patterns)
            remaining = [
                entry
                for pattern, entries in self.fallbacks.items()
                if self._pattern_matches(pattern, endpoint)
                for entry in entries
            ]
            if not remaining:
                self.telemetry.debug(
                    f"No fallbacks found for endpoint {endpoint}",
                    category=EventCategory.API
                )
                return None
            # Use provided strategy or default
            if strategy is None:
                strategy = FallbackStrategy(self.config["default_strategy"])
            call_args = args or []
            call_kwargs = kwargs or {}
            is_alternative = False
            while remaining:
                selected_entry = self._select_fallback(
                    endpoint, remaining, strategy, context
                )
                if selected_entry is None:
                    return None
                try:
                    result = selected_entry.get_value(*call_args, **call_kwargs)
                except Exception as e:
                    self.telemetry.error(
                        f"Error using fallback {selected_entry.key} for endpoint {endpoint}: {e}",
                        category=EventCategory.API,
                        exc_info=True
                    )
                    # Drop the failed entry by identity; keys may repeat across patterns
                    remaining = [
                        entry for entry in remaining
                        if entry is not selected_entry
                    ]
                    is_alternative = True
                    continue
                label = "alternative fallback" if is_alternative else "fallback"
                self.telemetry.debug(
                    f"Used {label} {selected_entry.key} for endpoint {endpoint}",
                    category=EventCategory.API,
                    context={"strategy": strategy.value}
                )
                return result
            return None

    def _select_fallback(
        self,
        endpoint: str,
        entries: List[FallbackEntry],
        strategy: FallbackStrategy,
        context: Optional[Dict[str, Any]] = None
    ) -> Optional[FallbackEntry]:
        """
        Select a fallback using the specified strategy.

        Candidates are first narrowed to the most preferred tier present
        (lowest tier value). The strategy is then applied within that tier.

        Args:
            endpoint: API endpoint
            entries: Available fallback entries
            strategy: Selection strategy
            context: Request context for context-based selection

        Returns:
            Selected fallback entry or None if no fallback is available
        """
        if not entries:
            return None
        # Health-based: prefer healthy entries, then degraded ones, else keep all
        if strategy == FallbackStrategy.HEALTH_BASED:
            for wanted in (HealthStatus.HEALTHY, HealthStatus.DEGRADED):
                preferred = [
                    entry for entry in entries
                    if entry.metadata.health_status == wanted
                ]
                if preferred:
                    entries = preferred
                    break
        # Group by tier
        tier_groups: Dict[FallbackTier, List[FallbackEntry]] = {}
        for entry in entries:
            tier_groups.setdefault(entry.metadata.tier, []).append(entry)
        # Compare tiers by value: plain Enum members do not support "<"
        highest_tier = min(tier_groups, key=lambda tier: getattr(tier, "value", tier))
        tier_entries = tier_groups[highest_tier]
        if len(tier_entries) == 1:
            return tier_entries[0]
        # Apply strategy within the tier
        if strategy == FallbackStrategy.RANDOM:
            return random.choice(tier_entries)
        if strategy == FallbackStrategy.WEIGHTED:
            # random.choices rejects a non-positive total; fall back to uniform choice
            weights = [max(entry.metadata.weight, 0) for entry in tier_entries]
            if sum(weights) <= 0:
                return random.choice(tier_entries)
            return random.choices(tier_entries, weights=weights, k=1)[0]
        if strategy == FallbackStrategy.LEAST_RECENTLY_USED:
            return min(tier_entries, key=lambda e: e.metadata.last_used)
        if strategy == FallbackStrategy.ROUND_ROBIN:
            index = self.round_robin_indexes.get(endpoint, 0)
            selected = tier_entries[index % len(tier_entries)]
            # Update index for next time
            self.round_robin_indexes[endpoint] = (index + 1) % len(tier_entries)
            return selected
        if strategy == FallbackStrategy.CONTEXT_BASED and context:
            for entry in tier_entries:
                if self._matches_context(entry, context):
                    return entry
        # STATIC, HEALTH_BASED, unmatched context or unknown strategy: first entry
        return tier_entries[0]

    @staticmethod
    def _evaluate_rule(operator: Any, context_value: Any, value: Any) -> bool:
        """
        Evaluate one operator rule against a context value.

        Values that cannot be compared (TypeError) count as a non-match.
        Unknown operators are ignored, i.e. they do not reject the entry.
        """
        numeric = isinstance(context_value, (int, float))
        try:
            if operator == "eq":
                return context_value == value
            if operator == "ne":
                return context_value != value
            if operator == "gt":
                return numeric and context_value > value
            if operator == "lt":
                return numeric and context_value < value
            if operator == "gte":
                return numeric and context_value >= value
            if operator == "lte":
                return numeric and context_value <= value
            if operator == "contains":
                return (
                    isinstance(context_value, (str, list, tuple))
                    and value in context_value
                )
            if operator == "in":
                return context_value in value
        except TypeError:
            return False
        return True

    def _matches_context(self, entry: FallbackEntry, context: Dict[str, Any]) -> bool:
        """
        Check if an entry matches the given context.

        Every rule must hold. A rule is either a plain value (equality check)
        or a dict with an "operator" and a "value" (eq, ne, gt, lt, gte, lte,
        contains, in). An entry without rules never matches.

        Args:
            entry: Fallback entry
            context: Request context

        Returns:
            Whether the entry matches the context
        """
        rules = entry.metadata.context_rules
        if not rules:
            return False
        for key, rule_value in rules.items():
            if key not in context:
                return False
            context_value = context[key]
            if isinstance(rule_value, dict) and "operator" in rule_value:
                # A missing "value" is treated as None instead of raising KeyError
                if not self._evaluate_rule(
                    rule_value["operator"], context_value, rule_value.get("value")
                ):
                    return False
            # Simple equality check
            elif context_value != rule_value:
                return False
        return True

    def is_degraded_mode(self) -> bool:
        """
        Check if the system is in degraded mode.

        Returns:
            Whether the system is in degraded mode
        """
        return self.degraded_mode

    def get_fallbacks_for_endpoint(self, endpoint: str) -> List[Dict[str, Any]]:
        """
        Get all fallbacks registered under exactly this endpoint key.

        Args:
            endpoint: API endpoint

        Returns:
            List of fallback entries as dictionaries
        """
        with self.fallbacks_lock:
            return [entry.to_dict() for entry in self.fallbacks.get(endpoint, [])]

    def get_all_fallbacks(self) -> Dict[str, List[Dict[str, Any]]]:
        """
        Get all registered fallbacks.

        Returns:
            Dictionary of fallbacks by endpoint
        """
        with self.fallbacks_lock:
            return {
                endpoint: [entry.to_dict() for entry in entries]
                for endpoint, entries in self.fallbacks.items()
            }

    def clear_fallbacks(self, endpoint: Optional[str] = None) -> None:
        """
        Clear fallbacks.

        Args:
            endpoint: Optional endpoint to clear (clears all if None). An
                empty string is treated as an endpoint name, not as "all".
        """
        with self.fallbacks_lock:
            if endpoint is not None:
                self.round_robin_indexes.pop(endpoint, None)
                if endpoint in self.fallbacks:
                    del self.fallbacks[endpoint]
                    self.telemetry.debug(
                        f"Cleared fallbacks for endpoint {endpoint}",
                        category=EventCategory.API
                    )
            else:
                self.fallbacks.clear()
                self.round_robin_indexes.clear()
                self.telemetry.debug(
                    "Cleared all fallbacks",
                    category=EventCategory.API
                )

    def shutdown(self) -> None:
        """Shutdown the registry: stop the health check worker and save to storage if enabled."""
        self._stop_event.set()
        worker = self.health_check_thread
        # Wait briefly so a save in the worker does not overlap the one below
        if (
            worker is not None
            and worker.is_alive()
            and worker is not threading.current_thread()
        ):
            worker.join(timeout=5.0)
        # Save to storage if enabled
        if self.config["save_to_storage"] and self.config["fallback_storage_path"]:
            self._save_to_storage()
# Module-wide registry, created lazily by get_fallback_registry()
_instance = None


def get_fallback_registry(options: Dict[str, Any] = None) -> APIFallbackRegistry:
    """
    Return the shared fallback registry, creating it on first use.

    Args:
        options: Configuration options; only used by the call that creates
            the instance, later calls return the existing registry unchanged

    Returns:
        APIFallbackRegistry instance
    """
    global _instance
    if _instance is not None:
        return _instance
    _instance = APIFallbackRegistry(options)
    return _instance