llm-api-proxy / src /rotator_library /error_handler.py
Mirrowel
feat: enhance changelog generation and error handling in build process
5db0197
raw
history blame
6.08 kB
import re
from typing import Optional, Dict, Any
from litellm.exceptions import APIConnectionError, RateLimitError, ServiceUnavailableError, AuthenticationError, InvalidRequestError, BadRequestError, OpenAIError, InternalServerError
class ClassifiedError:
    """Container describing an exception after it has been categorised.

    Attributes are plain and public:
      * ``error_type`` - the category label assigned to the exception.
      * ``original_exception`` - the exception object that was classified.
      * ``status_code`` - HTTP-like status code, or ``None`` when not supplied.
      * ``retry_after`` - seconds to wait before retrying, or ``None``.
    """

    def __init__(
        self,
        error_type: str,
        original_exception: Exception,
        status_code: Optional[int] = None,
        retry_after: Optional[int] = None,
    ):
        self.retry_after = retry_after
        self.status_code = status_code
        self.original_exception = original_exception
        self.error_type = error_type

    def __str__(self):
        # Assemble the individual "name=value" fragments, then join them.
        fragments = (
            f"type={self.error_type}",
            f"status={self.status_code}",
            f"retry_after={self.retry_after}",
            f"original_exc={self.original_exception}",
        )
        return "ClassifiedError(" + ", ".join(fragments) + ")"
import json
def _parse_retry_delay(delay: Any) -> Optional[int]:
    """Convert a ``retryDelay`` value into a whole number of seconds.

    Accepts the duration-string form (``"30s"``, ``"1.5s"``), the object form
    (``{"seconds": 30}``) and bare numbers. Fractional values are rounded up
    so the caller never retries too early. Returns ``None`` when the value is
    not a positive, finite duration.
    """
    import math  # local import keeps this block independent of the file header

    if isinstance(delay, dict):
        delay = delay.get('seconds')
    if isinstance(delay, bool):
        # bool is a subclass of int; it is never a meaningful duration.
        return None
    if isinstance(delay, str):
        text = delay.strip()
        if text[-1:] in ('s', 'S'):
            text = text[:-1]
        try:
            delay = float(text)
        except ValueError:
            return None
    if not isinstance(delay, (int, float)):
        return None
    if isinstance(delay, float) and not math.isfinite(delay):
        return None
    if delay <= 0:
        return None
    return math.ceil(delay)


def get_retry_after(error: Exception) -> Optional[int]:
    """
    Extracts the 'retry-after' duration in seconds from an exception.

    Three sources are consulted, in order:

    1. A JSON body embedded in ``str(error)``: every entry of
       ``error.details`` that carries a ``retryDelay`` is examined (either the
       ``"30s"`` string form or the ``{"seconds": 30}`` object form).
    2. Free-text patterns such as ``retry after: 30`` or ``wait for 5 seconds``,
       matched case-insensitively.
    3. A ``retry_after`` attribute on the exception object itself.

    Args:
        error: The exception to inspect.

    Returns:
        The delay in seconds, or ``None`` if no delay could be determined.
        Malformed payloads never raise; they simply fall through to the next
        source.
    """
    # Keep the original casing: JSON keys such as 'retryDelay' are
    # case-sensitive, so lower-casing the message would hide them.
    raw_message = str(error)

    # 1. Try to parse JSON embedded in the message to find 'retryDelay'.
    #    DOTALL lets the match span pretty-printed, multi-line bodies.
    json_match = re.search(r'\{.*\}', raw_message, re.DOTALL)
    if json_match:
        try:
            payload = json.loads(json_match.group(0))
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            payload = None
        error_body = payload.get('error') if isinstance(payload, dict) else None
        details = error_body.get('details') if isinstance(error_body, dict) else None
        if isinstance(details, list):
            # The retry information is not necessarily the first detail entry.
            for detail in details:
                if isinstance(detail, dict) and 'retryDelay' in detail:
                    seconds = _parse_retry_delay(detail['retryDelay'])
                    if seconds is not None:
                        return seconds

    # 2. Common textual patterns for 'retry-after'.
    patterns = (
        r'retry after:?\s*(\d+)',
        r'retry_after:?\s*(\d+)',
        r'retry in\s*(\d+)\s*seconds',
        r'wait for\s*(\d+)\s*seconds',
        # Matches both JSON ("...") and Python-repr ('...') quoting.
        r'''["']retryDelay["']:\s*["'](\d+)s["']''',
    )
    for pattern in patterns:
        match = re.search(pattern, raw_message, re.IGNORECASE)
        if match:
            try:
                return int(match.group(1))
            except (ValueError, IndexError):
                continue

    # 3. Handle cases where the error object itself has the attribute.
    if hasattr(error, 'retry_after'):
        value = getattr(error, 'retry_after')
        if isinstance(value, int):
            return value
        if isinstance(value, str) and value.isdigit():
            return int(value)

    return None
def classify_error(e: Exception) -> ClassifiedError:
    """
    Map an exception onto a ClassifiedError describing its category.

    The exception's own ``status_code`` attribute is used when it is set and
    truthy; otherwise a per-category default is substituted (429 for rate
    limits, 401 for authentication, 400 for invalid requests, 503 for server
    errors). Exceptions matching no category are labelled ``'unknown'`` and
    keep whatever ``status_code`` they carried (possibly ``None``).
    """
    code = getattr(e, 'status_code', None)

    # Rate limits are the only category that also carries a retry delay.
    if isinstance(e, RateLimitError):
        delay = get_retry_after(e)
        return ClassifiedError('rate_limit', e, code or 429, delay)

    # (exception types, category label, fallback status) - checked in order.
    categories = (
        ((AuthenticationError,), 'authentication', 401),
        ((InvalidRequestError, BadRequestError), 'invalid_request', 400),
        # These are often temporary server-side issues.
        (
            (ServiceUnavailableError, APIConnectionError, OpenAIError, InternalServerError),
            'server_error',
            503,
        ),
    )
    for exc_types, label, fallback_status in categories:
        if isinstance(e, exc_types):
            return ClassifiedError(label, e, code or fallback_status)

    # Nothing matched: report the error as unclassified.
    return ClassifiedError('unknown', e, code)
def is_rate_limit_error(e: Exception) -> bool:
    """Return True when ``e`` is an instance of litellm's RateLimitError."""
    rate_limited = isinstance(e, RateLimitError)
    return rate_limited
def is_server_error(e: Exception) -> bool:
    """Return True when ``e`` belongs to one of the server-side error classes.

    The classes checked are ServiceUnavailableError, APIConnectionError,
    InternalServerError and OpenAIError.
    """
    server_side_types = (
        ServiceUnavailableError,
        APIConnectionError,
        InternalServerError,
        OpenAIError,
    )
    return isinstance(e, server_side_types)
def is_unrecoverable_error(e: Exception) -> bool:
    """
    Return True when ``e`` is a client-side error that retrying will not fix.

    The classes checked are InvalidRequestError, AuthenticationError and
    BadRequestError.
    """
    non_retriable_types = (
        InvalidRequestError,
        AuthenticationError,
        BadRequestError,
    )
    return isinstance(e, non_retriable_types)
class AllProviders:
    """
    A class to handle provider-specific settings, such as custom API bases.

    ``self.providers`` maps a provider name (the part of a model string before
    the first ``/``) to a settings dict. Recognised settings:

      * ``api_base`` - copied into the call kwargs as ``api_base``.
      * ``model_prefix`` - replaces the provider segment of the model string.
    """

    def __init__(self):
        self.providers = {
            "chutes": {
                "api_base": "https://llm.chutes.ai/v1",
                "model_prefix": "openai/"
            }
        }

    def get_provider_kwargs(self, **kwargs) -> Dict[str, Any]:
        """
        Returns provider-specific kwargs for a given model.

        Args:
            **kwargs: Call arguments; ``model`` is expected in the form
                ``"<provider>/<model name>"``.

        Returns:
            The kwargs dict with ``api_base`` and/or ``model`` overridden when
            the model's provider has matching settings. If ``model`` is missing
            or empty, or the provider is not registered, the kwargs are
            returned unchanged.
        """
        model = kwargs.get("model")
        if not model:
            return kwargs

        provider = self._get_provider_from_model(model)
        provider_settings = self.providers.get(provider, {})

        if "api_base" in provider_settings:
            kwargs["api_base"] = provider_settings["api_base"]

        if "model_prefix" in provider_settings:
            # partition() never raises: a bare provider name without a '/'
            # has no model segment to re-prefix, so the model is left as-is
            # instead of failing with IndexError.
            _, separator, model_name = model.partition('/')
            if separator:
                kwargs["model"] = f"{provider_settings['model_prefix']}{model_name}"

        return kwargs

    def _get_provider_from_model(self, model: str) -> str:
        """
        Determines the provider from the model name.

        The provider is the text before the first ``/``; a model string with
        no ``/`` is returned whole.
        """
        return model.split('/')[0]