Spaces:
Paused
Paused
File size: 6,079 Bytes
27b342a 5db0197 27b342a 5db0197 27b342a 5db0197 27b342a 5db0197 27b342a 5db0197 27b342a bf565ba 27b342a bf565ba 4bbfff4 bf565ba 27b342a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 | import re
from typing import Optional, Dict, Any
from litellm.exceptions import APIConnectionError, RateLimitError, ServiceUnavailableError, AuthenticationError, InvalidRequestError, BadRequestError, OpenAIError, InternalServerError
class ClassifiedError:
    """Structured result of classifying an exception.

    Attributes are plain values set once in ``__init__``:
    the category label, the exception that was classified, an optional
    HTTP-style status code, and an optional retry delay.
    """

    def __init__(self, error_type: str, original_exception: Exception, status_code: Optional[int] = None, retry_after: Optional[int] = None):
        # Store the arguments verbatim; no validation or conversion is done.
        self.retry_after = retry_after
        self.status_code = status_code
        self.original_exception = original_exception
        self.error_type = error_type

    def __str__(self):
        """Render every field on one line for log output."""
        rendered_fields = (
            f"type={self.error_type}",
            f"status={self.status_code}",
            f"retry_after={self.retry_after}",
            f"original_exc={self.original_exception}",
        )
        return "ClassifiedError(" + ", ".join(rendered_fields) + ")"
import json
def _coerce_retry_seconds(value: Any) -> Optional[int]:
    """Convert a retry-delay value to whole seconds, or None if unusable.

    Accepts ints, floats, and strings such as "30", "30s" or "3.5s".
    Fractional values are rounded up so a caller never retries too early.
    """
    import math  # local import: the file header does not import math

    if isinstance(value, int):
        return value
    if isinstance(value, str):
        text = value.strip().lower()
        if text.endswith('s'):
            text = text[:-1]
        try:
            value = float(text)
        except ValueError:
            return None
    # Reject nan/inf, which cannot be turned into an int.
    if isinstance(value, float) and math.isfinite(value):
        return math.ceil(value)
    return None


def _retry_delay_from_json(message: str) -> Optional[int]:
    """Find a google.rpc.RetryInfo delay in a JSON body embedded in `message`."""
    # DOTALL so that pretty-printed (multi-line) JSON bodies are captured too.
    json_match = re.search(r'\{.*\}', message, re.DOTALL)
    if not json_match:
        return None
    try:
        payload = json.loads(json_match.group(0))
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return None

    # Validate each level explicitly: "error" may be a plain string and
    # "details" may be missing, and neither should raise.
    error_body = payload.get('error') if isinstance(payload, dict) else None
    details = error_body.get('details') if isinstance(error_body, dict) else None
    if not isinstance(details, list):
        return None

    # RetryInfo is not guaranteed to be the first detail entry, so scan all.
    for detail in details:
        if not isinstance(detail, dict):
            continue
        if detail.get('@type') != 'type.googleapis.com/google.rpc.RetryInfo':
            continue
        delay = detail.get('retryDelay')
        if isinstance(delay, dict):
            # Object form: {"retryDelay": {"seconds": ...}}
            delay = delay.get('seconds')
        seconds = _coerce_retry_seconds(delay)
        if seconds is not None:
            return seconds
    return None


def get_retry_after(error: Exception) -> Optional[int]:
    """
    Extracts the 'retry-after' duration in seconds from an exception.

    Sources are tried in this order, and the first usable value wins:

    1. A JSON body embedded in ``str(error)`` whose ``error.details`` list
       contains a ``google.rpc.RetryInfo`` entry. ``retryDelay`` may be a
       string such as ``"17s"`` or an object with a ``seconds`` field.
    2. Textual patterns in the message such as "retry after 30",
       "retry-after: 30", "retry in 30 seconds" or "wait for 30 seconds".
    3. A ``retry_after`` attribute on the exception object itself.

    Matching is case-insensitive, but the message is NOT lowercased first:
    the JSON keys and the RetryInfo type URL are case-sensitive and would
    never match a lowercased copy.

    Args:
        error: The exception to inspect.

    Returns:
        The delay in whole seconds (fractions rounded up), or None when no
        delay could be determined. Malformed payloads yield None, not an error.
    """
    message = str(error)

    # 1. Structured RetryInfo inside an embedded JSON body.
    delay = _retry_delay_from_json(message)
    if delay is not None:
        return delay

    # 2. Common textual patterns for 'retry-after'.
    patterns = (
        r'retry[ _-]after:?\s*(\d+)',
        r'retry in\s*(\d+)\s*seconds',
        r'wait for\s*(\d+)\s*seconds',
        r'"retryDelay":\s*"(\d+(?:\.\d+)?)s"',
    )
    for pattern in patterns:
        match = re.search(pattern, message, re.IGNORECASE)
        if match:
            delay = _coerce_retry_seconds(match.group(1))
            if delay is not None:
                return delay

    # 3. The error object itself may carry the value as an attribute.
    return _coerce_retry_seconds(getattr(error, 'retry_after', None))
def classify_error(e: Exception) -> ClassifiedError:
    """
    Map an exception onto a ClassifiedError describing its category.

    The exception's own ``status_code`` attribute is used when it is truthy;
    otherwise each recognised category supplies a default code. Exceptions
    that match no category are labelled 'unknown' and keep whatever
    ``status_code`` they carried (possibly None).
    """
    code = getattr(e, 'status_code', None)

    # Rate limits are handled separately: only they carry a retry delay.
    if isinstance(e, RateLimitError):
        return ClassifiedError(
            error_type='rate_limit',
            original_exception=e,
            status_code=code or 429,
            retry_after=get_retry_after(e),
        )

    # Ordered (exception types, label, default status) table; first match wins.
    categories = (
        ((AuthenticationError,), 'authentication', 401),
        ((InvalidRequestError, BadRequestError), 'invalid_request', 400),
        # These are often temporary server-side issues
        ((ServiceUnavailableError, APIConnectionError, OpenAIError, InternalServerError), 'server_error', 503),
    )
    for exc_types, label, default_code in categories:
        if isinstance(e, exc_types):
            return ClassifiedError(
                error_type=label,
                original_exception=e,
                status_code=code or default_code,
            )

    # Fallback for any other unclassified errors
    return ClassifiedError(error_type='unknown', original_exception=e, status_code=code)
def is_rate_limit_error(e: Exception) -> bool:
    """Return True when ``e`` is an instance of ``RateLimitError``."""
    rate_limited = isinstance(e, RateLimitError)
    return rate_limited
def is_server_error(e: Exception) -> bool:
    """Return True when ``e`` is one of the temporary server-side error types."""
    transient_types = (
        ServiceUnavailableError,
        APIConnectionError,
        InternalServerError,
        OpenAIError,
    )
    return isinstance(e, transient_types)
def is_unrecoverable_error(e: Exception) -> bool:
    """
    Return True when ``e`` is a non-retriable client-side error.

    These are the request and authentication error types, which will not
    resolve on their own if the same call is repeated.
    """
    client_side_types = (
        InvalidRequestError,
        AuthenticationError,
        BadRequestError,
    )
    return isinstance(e, client_side_types)
class AllProviders:
    """
    A class to handle provider-specific settings, such as custom API bases.

    ``self.providers`` maps a provider name (the part of a model string
    before the first "/") to a settings dict. Two settings keys are used:

    * ``"api_base"``: value written to the ``api_base`` kwarg.
    * ``"model_prefix"``: replaces the ``"<provider>/"`` prefix of the model.
    """

    def __init__(self):
        self.providers = {
            "chutes": {
                "api_base": "https://llm.chutes.ai/v1",
                "model_prefix": "openai/"
            }
        }

    def get_provider_kwargs(self, **kwargs) -> Dict[str, Any]:
        """
        Returns provider-specific kwargs for a given model.

        For a model such as ``"chutes/org/name"`` whose provider is
        configured, ``api_base`` is set from the provider settings and the
        model is rewritten to ``"<model_prefix>org/name"``.

        The kwargs are returned unchanged when there is no ``model``, when the
        provider is not configured, or when the model string has no "/"
        separator. A bare name carries no provider prefix, even if it happens
        to equal a provider key; that case used to raise ``IndexError``.

        Args:
            **kwargs: Completion-call keyword arguments; only ``model`` is read.

        Returns:
            The same kwargs dict, possibly with ``api_base`` and ``model``
            overwritten.
        """
        model = kwargs.get("model")
        if not model:
            return kwargs

        provider = self._get_provider_from_model(model)
        provider_settings = self.providers.get(provider)
        # partition never raises: separator is '' when the model has no "/".
        _, separator, model_name = model.partition('/')
        if not provider_settings or not separator:
            return kwargs

        if "api_base" in provider_settings:
            kwargs["api_base"] = provider_settings["api_base"]
        if "model_prefix" in provider_settings:
            kwargs["model"] = f"{provider_settings['model_prefix']}{model_name}"
        return kwargs

    def _get_provider_from_model(self, model: str) -> str:
        """
        Determines the provider from the model name.

        Returns the text before the first "/", or the whole string when the
        model contains no "/".
        """
        return model.partition('/')[0]
|