Spaces:
Paused
Paused
| import re | |
| from typing import Optional, Dict, Any | |
| from litellm.exceptions import APIConnectionError, RateLimitError, ServiceUnavailableError, AuthenticationError, InvalidRequestError, BadRequestError, OpenAIError, InternalServerError | |
class ClassifiedError:
    """Pairs an exception with the category it was classified into.

    Attributes:
        error_type: Category label for the error.
        original_exception: The exception that was classified.
        status_code: Status code associated with the error, if any.
        retry_after: Suggested wait before retrying, if any.
    """

    def __init__(self, error_type: str, original_exception: Exception, status_code: Optional[int] = None, retry_after: Optional[int] = None):
        # All four values are stored as given; nothing is validated here.
        self.error_type, self.original_exception = error_type, original_exception
        self.status_code, self.retry_after = status_code, retry_after

    def __str__(self):
        """Render every field in a single-line, human-readable form."""
        rendered_fields = (
            f"type={self.error_type}",
            f"status={self.status_code}",
            f"retry_after={self.retry_after}",
            f"original_exc={self.original_exception}",
        )
        return "ClassifiedError(" + ", ".join(rendered_fields) + ")"
| import json | |
# '@type' value that marks the retry hint entry in a Google-style error body.
_RETRY_INFO_TYPE = 'type.googleapis.com/google.rpc.RetryInfo'

# A bare duration such as "34", "34s" or "1.5s".
_DURATION_RE = re.compile(r'(\d+(?:\.\d+)?)\s*s?', re.IGNORECASE)

# Free-text hints, tried in order against the raw (not lowercased) message.
_RETRY_AFTER_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r'retry[-_ ]after:?\s*(\d+(?:\.\d+)?)',
        r'retry in\s*(\d+(?:\.\d+)?)\s*seconds',
        r'wait for\s*(\d+(?:\.\d+)?)\s*seconds',
        r'''['"]retryDelay['"]:\s*['"](\d+(?:\.\d+)?)s['"]''',
    )
)


def _coerce_seconds(value) -> Optional[int]:
    """Convert a number or duration string to whole, non-negative seconds."""
    import math  # local import: the module does not import math at file level

    if isinstance(value, bool):
        return None
    if isinstance(value, str):
        match = _DURATION_RE.fullmatch(value.strip())
        if match is None:
            return None
        value = float(match.group(1))
    if not isinstance(value, (int, float)):
        return None
    try:
        # Round up so a fractional delay never leads to retrying too early.
        seconds = math.ceil(value)
    except (OverflowError, ValueError):  # inf / nan
        return None
    return seconds if seconds >= 0 else None


def _retry_delay_from_payload(payload) -> Optional[int]:
    """Find the RetryInfo delay in a parsed ``{"error": {"details": [...]}}`` body."""
    if not isinstance(payload, dict):
        return None
    error_body = payload.get('error')
    if not isinstance(error_body, dict):
        return None
    details = error_body.get('details')
    if not isinstance(details, list):
        return None
    # The RetryInfo entry is not necessarily first, so scan every detail.
    for detail in details:
        if not isinstance(detail, dict) or detail.get('@type') != _RETRY_INFO_TYPE:
            continue
        delay = detail.get('retryDelay')
        if isinstance(delay, dict):
            # Object form: {"retryDelay": {"seconds": 34}}
            delay = delay.get('seconds')
        seconds = _coerce_seconds(delay)
        if seconds is not None:
            return seconds
    return None


def get_retry_after(error: Exception) -> Optional[int]:
    """
    Extracts the 'retry-after' duration in seconds from an exception.

    Sources are tried in this order:
      1. A JSON body embedded in ``str(error)`` that carries a RetryInfo
         detail with a ``retryDelay`` (either ``"34s"`` or ``{"seconds": 34}``).
      2. Free-text hints in the message ("retry after 30", "retry in 5 seconds",
         "wait for 10 seconds", a quoted ``retryDelay`` value).
      3. A ``retry_after`` attribute on the exception object.

    Args:
        error: The exception to inspect.

    Returns:
        The delay in whole seconds (fractional values are rounded up), or
        None when no usable hint is found. Malformed input never raises.
    """
    # Case is preserved on purpose: JSON keys and the '@type' value are
    # case-sensitive, so lowercasing the text would make them unmatchable.
    error_text = str(error)

    # 1. Structured JSON embedded in the message (may span several lines).
    json_match = re.search(r'\{.*\}', error_text, re.DOTALL)
    if json_match:
        try:
            payload = json.loads(json_match.group(0))
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            payload = None
        seconds = _retry_delay_from_payload(payload)
        if seconds is not None:
            return seconds

    # 2. Common free-text patterns.
    for pattern in _RETRY_AFTER_PATTERNS:
        match = pattern.search(error_text)
        if match:
            seconds = _coerce_seconds(match.group(1))
            if seconds is not None:
                return seconds

    # 3. Attribute on the exception object itself.
    return _coerce_seconds(getattr(error, 'retry_after', None))
def classify_error(e: Exception) -> ClassifiedError:
    """
    Wraps an exception in a ClassifiedError describing its category.

    The exception is checked against the known error families in a fixed
    order, and the first match wins. When the exception has no truthy
    ``status_code`` attribute, a per-category default is used instead
    (429, 401, 400 or 503). Unmatched exceptions are labelled 'unknown'
    and keep whatever status code they carried, possibly None.
    """
    observed_status = getattr(e, 'status_code', None)

    # Rate limits are the only category that also carries a retry hint.
    if isinstance(e, RateLimitError):
        wait_seconds = get_retry_after(e)
        return ClassifiedError(
            error_type='rate_limit',
            original_exception=e,
            status_code=observed_status or 429,
            retry_after=wait_seconds,
        )

    # (exception types, category label, fallback status), checked in order.
    rules = (
        ((AuthenticationError,), 'authentication', 401),
        ((InvalidRequestError, BadRequestError), 'invalid_request', 400),
        # These are often temporary server-side issues.
        ((ServiceUnavailableError, APIConnectionError, OpenAIError, InternalServerError), 'server_error', 503),
    )
    for exc_types, label, fallback_status in rules:
        if isinstance(e, exc_types):
            return ClassifiedError(
                error_type=label,
                original_exception=e,
                status_code=observed_status or fallback_status,
            )

    # Nothing matched: report it as unclassified.
    return ClassifiedError(
        error_type='unknown',
        original_exception=e,
        status_code=observed_status,
    )
def is_rate_limit_error(e: Exception) -> bool:
    """Tells whether the exception is an instance of RateLimitError."""
    rate_limited = isinstance(e, RateLimitError)
    return rate_limited
def is_server_error(e: Exception) -> bool:
    """Tells whether the exception belongs to the temporary server-side error types."""
    server_side_types = (
        ServiceUnavailableError,
        APIConnectionError,
        InternalServerError,
        OpenAIError,
    )
    return isinstance(e, server_side_types)
def is_unrecoverable_error(e: Exception) -> bool:
    """
    Tells whether the exception is a non-retriable client-side error.

    These are the invalid-request, bad-request and authentication error
    types, which will not resolve on their own.
    """
    client_side_types = (InvalidRequestError, AuthenticationError, BadRequestError)
    return isinstance(e, client_side_types)
class AllProviders:
    """
    A class to handle provider-specific settings, such as custom API bases.

    ``self.providers`` maps a provider name (the part of a model string
    before the first '/') to its settings. Recognised settings are
    ``api_base`` and ``model_prefix``.
    """

    def __init__(self):
        self.providers = {
            "chutes": {
                "api_base": "https://llm.chutes.ai/v1",
                "model_prefix": "openai/"
            }
        }

    def get_provider_kwargs(self, **kwargs) -> Dict[str, Any]:
        """
        Returns provider-specific kwargs for a given model.

        For a model written as ``<provider>/<name>`` whose provider has an
        entry in ``self.providers``:
          * ``api_base`` is set from the provider settings, and
          * ``model`` is rewritten to ``<model_prefix><name>``.

        The kwargs are returned unchanged when ``model`` is missing, empty,
        not a string, contains no '/', or names an unknown provider.

        Returns:
            The (possibly updated) kwargs dict.
        """
        model = kwargs.get("model")
        # Without a '/' there is no provider-qualified name to rewrite.
        # Returning early also avoids indexing past the end of the split.
        if not model or not isinstance(model, str) or '/' not in model:
            return kwargs

        provider = self._get_provider_from_model(model)
        provider_settings = self.providers.get(provider, {})

        if "api_base" in provider_settings:
            kwargs["api_base"] = provider_settings["api_base"]

        if "model_prefix" in provider_settings:
            # Replace only the leading provider segment and keep any later
            # '/' in the model name (e.g. "org/model").
            _, _, model_name = model.partition('/')
            kwargs["model"] = f"{provider_settings['model_prefix']}{model_name}"

        return kwargs

    def _get_provider_from_model(self, model: str) -> str:
        """
        Determines the provider from the model name.

        Returns the text before the first '/', or the whole string when
        there is no '/'.
        """
        return model.split('/')[0]