File size: 6,079 Bytes
27b342a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5db0197
 
27b342a
 
 
5db0197
27b342a
 
5db0197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27b342a
 
 
 
 
5db0197
27b342a
 
 
 
 
 
 
 
 
 
5db0197
27b342a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf565ba
 
 
 
 
 
 
27b342a
bf565ba
 
 
4bbfff4
 
bf565ba
27b342a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import re
from typing import Optional, Dict, Any

from litellm.exceptions import APIConnectionError, RateLimitError, ServiceUnavailableError, AuthenticationError, InvalidRequestError, BadRequestError, OpenAIError, InternalServerError

class ClassifiedError:
    """Pairs an exception with the category it was classified into.

    Attributes are plain instance fields:
      error_type         -- category label supplied by the caller.
      original_exception -- the exception object that was classified.
      status_code        -- optional integer status; None when not supplied.
      retry_after        -- optional delay in seconds; None when not supplied.
    """

    def __init__(
        self,
        error_type: str,
        original_exception: Exception,
        status_code: Optional[int] = None,
        retry_after: Optional[int] = None,
    ):
        self.error_type, self.original_exception = error_type, original_exception
        self.status_code, self.retry_after = status_code, retry_after

    def __str__(self):
        # Single-line summary; the wrapped exception is rendered via its own str().
        return (
            f"ClassifiedError(type={self.error_type}, status={self.status_code}, "
            f"retry_after={self.retry_after}, original_exc={self.original_exception})"
        )

import json

# '@type' value that marks a google.rpc RetryInfo entry inside error.details.
_RETRY_INFO_TYPE = 'type.googleapis.com/google.rpc.RetryInfo'

# Free-text phrasings of a retry delay. They are matched case-insensitively
# against the ORIGINAL (not lowercased) message so that mixed-case tokens such
# as 'retryDelay' can still be recognised.
_RETRY_AFTER_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r'retry[-_ ]after:?\s*(\d+(?:\.\d+)?)',
        r'retry in\s*(\d+(?:\.\d+)?)\s*s(?:ecs?|econds?)?\b',
        r'wait for\s*(\d+(?:\.\d+)?)\s*s(?:ecs?|econds?)?\b',
        # Tolerates JSON, escaped-JSON and Python-repr quoting around the key.
        r'retryDelay\W{1,8}(\d+(?:\.\d+)?)s',
    )
)


def _coerce_delay_seconds(value: Any) -> Optional[int]:
    """Convert a delay value to whole seconds (rounded up), or None if unusable.

    Accepts ints, floats, numeric strings with an optional trailing 's'
    (e.g. "30", "1.5s"), and dicts carrying a 'seconds' entry (any other
    entries of the dict are ignored).
    """
    import math  # local import: math is not among the module-level imports

    if isinstance(value, dict):
        value = value.get('seconds')
    if isinstance(value, str):
        value = value.strip()
        if value[-1:] in ('s', 'S'):
            value = value[:-1]
    # bool is a subclass of int but is never a meaningful delay.
    if value is None or isinstance(value, bool):
        return None
    try:
        seconds = float(value)
    except (TypeError, ValueError, OverflowError):
        return None
    if not math.isfinite(seconds) or seconds < 0:
        return None
    # Round up so a fractional delay never results in retrying too early.
    return math.ceil(seconds)


def _retry_delay_from_json(text: str) -> Optional[int]:
    """Return the RetryInfo delay from a JSON object embedded in *text*, if any."""
    # DOTALL lets the match span pretty-printed (multi-line) JSON bodies.
    match = re.search(r'\{.*\}', text, re.DOTALL)
    if not match:
        return None
    try:
        payload = json.loads(match.group(0))
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return None

    # Validate each level explicitly: any of them may be a string, null, etc.
    error_body = payload.get('error') if isinstance(payload, dict) else None
    details = error_body.get('details') if isinstance(error_body, dict) else None
    if not isinstance(details, list):
        return None

    # The RetryInfo entry is not necessarily the first element of 'details'.
    for detail in details:
        if isinstance(detail, dict) and detail.get('@type') == _RETRY_INFO_TYPE:
            delay = _coerce_delay_seconds(detail.get('retryDelay'))
            if delay is not None:
                return delay
    return None


def get_retry_after(error: Exception) -> Optional[int]:
    """
    Extracts the 'retry-after' duration in seconds from an exception.

    Sources are tried in this order, and the first usable value wins:
      1. A JSON object embedded in ``str(error)`` whose ``error.details`` list
         contains a RetryInfo entry with a ``retryDelay`` (either a string such
         as "30s" or a dict with a ``seconds`` entry).
      2. Free-text phrasings in the message ("retry after 30",
         "retry-after: 30", "retry in 30 seconds", "wait for 30 seconds",
         or a ``retryDelay`` token followed by "<n>s").
      3. A ``retry_after`` attribute on the exception object itself.

    Args:
        error: The exception to inspect.

    Returns:
        The delay in whole seconds (fractions are rounded up), or None when no
        delay could be determined. Malformed input never raises.
    """
    message = str(error)

    delay = _retry_delay_from_json(message)
    if delay is not None:
        return delay

    for pattern in _RETRY_AFTER_PATTERNS:
        match = pattern.search(message)
        if match:
            delay = _coerce_delay_seconds(match.group(1))
            if delay is not None:
                return delay

    return _coerce_delay_seconds(getattr(error, 'retry_after', None))

def classify_error(e: Exception) -> ClassifiedError:
    """
    Wrap an exception in a ClassifiedError describing which category it falls in.

    Categories are checked in a fixed order and the first match wins:
    'rate_limit', 'authentication', 'invalid_request', 'server_error'.
    Anything that matches none of them is labelled 'unknown'.
    """
    code = getattr(e, 'status_code', None)

    # (exception types, category label, status used when the exception's own
    #  status_code is missing or falsy)
    categories = (
        ((RateLimitError,), 'rate_limit', 429),
        ((AuthenticationError,), 'authentication', 401),
        ((InvalidRequestError, BadRequestError), 'invalid_request', 400),
        # These are often temporary server-side issues
        ((ServiceUnavailableError, APIConnectionError, OpenAIError, InternalServerError), 'server_error', 503),
    )

    for exc_types, label, fallback_code in categories:
        if not isinstance(e, exc_types):
            continue
        # Only rate-limit errors are inspected for a retry delay.
        delay = get_retry_after(e) if label == 'rate_limit' else None
        return ClassifiedError(
            error_type=label,
            original_exception=e,
            status_code=code or fallback_code,
            retry_after=delay,
        )

    # No category matched: keep whatever status the exception carried (maybe None).
    return ClassifiedError(error_type='unknown', original_exception=e, status_code=code)

def is_rate_limit_error(e: Exception) -> bool:
    """Return True when ``e`` is an instance of RateLimitError."""
    rate_limited = isinstance(e, RateLimitError)
    return rate_limited

def is_server_error(e: Exception) -> bool:
    """Return True when ``e`` belongs to the temporary server-side error types."""
    server_side_types = (
        ServiceUnavailableError,
        APIConnectionError,
        InternalServerError,
        OpenAIError,
    )
    return isinstance(e, server_side_types)

def is_unrecoverable_error(e: Exception) -> bool:
    """
    Return True when ``e`` is a non-retriable client-side error.

    Covers invalid/bad requests and authentication failures, i.e. errors
    that will not resolve on their own.
    """
    client_side_types = (InvalidRequestError, AuthenticationError, BadRequestError)
    return isinstance(e, client_side_types)

class AllProviders:
    """
    A class to handle provider-specific settings, such as custom API bases.

    ``self.providers`` maps a provider name (the part of a model string before
    the first "/") to a settings dict. Recognised settings keys:
      "api_base"     -- value written to the ``api_base`` kwarg.
      "model_prefix" -- string that replaces the "<provider>/" part of the model.
    """
    def __init__(self):
        self.providers = {
            "chutes": {
                "api_base": "https://llm.chutes.ai/v1",
                "model_prefix": "openai/"
            }
        }

    def get_provider_kwargs(self, **kwargs) -> Dict[str, Any]:
        """
        Returns provider-specific kwargs for a given model.

        The keyword arguments are returned with ``api_base`` and ``model``
        overridden according to the settings of the model's provider. They are
        returned unchanged when ``model`` is missing/empty or its provider has
        no configured settings.

        A model string with no "/" (e.g. just ``"chutes"``) has no model name
        to re-prefix, so ``model`` is left as-is instead of raising IndexError.
        """
        model = kwargs.get("model")
        if not model:
            return kwargs

        provider = self._get_provider_from_model(model)
        provider_settings = self.providers.get(provider, {})

        if "api_base" in provider_settings:
            kwargs["api_base"] = provider_settings["api_base"]

        if "model_prefix" in provider_settings:
            # partition() never fails; `separator` is empty when there is no "/".
            _, separator, model_name = model.partition('/')
            if separator:
                kwargs["model"] = f"{provider_settings['model_prefix']}{model_name}"

        return kwargs

    def _get_provider_from_model(self, model: str) -> str:
        """
        Determines the provider from the model name.

        The provider is everything before the first "/"; a model string
        without "/" is returned whole.
        """
        return model.split('/')[0]