Mirrowel commited on
Commit
a725feb
·
1 Parent(s): c5716c1

refactor(client): 🔨 add comprehensive error handling and retry logic for custom provider non-streaming calls

Browse files

This change brings the non-streaming custom provider call path in line with the streaming path's robust error handling strategy.

- Implements a retry loop with attempt tracking and logging for custom provider calls
- Adds pre-request callback execution with configurable error handling
- Integrates error classification and rotation logic for rate limits, HTTP errors, and server errors
- Records errors in the accumulator for client-level reporting and visibility
- Implements exponential backoff with jitter for transient server errors
- Adds cooldown management for rate-limited providers
- Respects time budget constraints when calculating retry wait times
- Properly manages credential state (success/failure recording and key release)
- Distinguishes between recoverable errors (which trigger rotation) and non-recoverable errors (which fail immediately)

The retry loop handles three categories of exceptions:
1. Rate limits and HTTP status errors: trigger immediate rotation after recording
2. Connection and server errors: retry with backoff, rotate only after max retries
3. General exceptions: classify and rotate if recoverable, fail if not

Files changed (1) hide show
  1. src/rotator_library/client.py +178 -12
src/rotator_library/client.py CHANGED
@@ -1065,19 +1065,185 @@ class RotatingClient:
1065
  is_budget_enabled
1066
  )
1067
 
1068
- # The plugin handles the entire call, including retries on 401, etc.
1069
- # The main retry loop here is for key rotation on other errors.
1070
- response = await provider_plugin.acompletion(
1071
- self.http_client, **litellm_kwargs
1072
- )
 
1073
 
1074
- # For non-streaming, success is immediate, and this function only handles non-streaming.
1075
- await self.usage_manager.record_success(
1076
- current_cred, model, response
1077
- )
1078
- await self.usage_manager.release_key(current_cred, model)
1079
- key_acquired = False
1080
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1081
 
1082
  else: # This is the standard API Key / litellm-handled provider logic
1083
  is_oauth = provider in self.oauth_providers
 
1065
  is_budget_enabled
1066
  )
1067
 
1068
+ # Retry loop for custom providers - mirrors streaming path error handling
1069
+ for attempt in range(self.max_retries):
1070
+ try:
1071
+ lib_logger.info(
1072
+ f"Attempting call with credential {mask_credential(current_cred)} (Attempt {attempt + 1}/{self.max_retries})"
1073
+ )
1074
 
1075
+ if pre_request_callback:
1076
+ try:
1077
+ await pre_request_callback(request, litellm_kwargs)
1078
+ except Exception as e:
1079
+ if self.abort_on_callback_error:
1080
+ raise PreRequestCallbackError(
1081
+ f"Pre-request callback failed: {e}"
1082
+ ) from e
1083
+ else:
1084
+ lib_logger.warning(
1085
+ f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}"
1086
+ )
1087
+
1088
+ response = await provider_plugin.acompletion(
1089
+ self.http_client, **litellm_kwargs
1090
+ )
1091
+
1092
+ # For non-streaming, success is immediate
1093
+ await self.usage_manager.record_success(
1094
+ current_cred, model, response
1095
+ )
1096
+ await self.usage_manager.release_key(current_cred, model)
1097
+ key_acquired = False
1098
+ return response
1099
+
1100
+ except (
1101
+ litellm.RateLimitError,
1102
+ httpx.HTTPStatusError,
1103
+ ) as e:
1104
+ last_exception = e
1105
+ classified_error = classify_error(e, provider=provider)
1106
+ error_message = str(e).split("\n")[0]
1107
+
1108
+ log_failure(
1109
+ api_key=current_cred,
1110
+ model=model,
1111
+ attempt=attempt + 1,
1112
+ error=e,
1113
+ request_headers=dict(request.headers)
1114
+ if request
1115
+ else {},
1116
+ )
1117
+
1118
+ # Record in accumulator for client reporting
1119
+ error_accumulator.record_error(
1120
+ current_cred, classified_error, error_message
1121
+ )
1122
+
1123
+ # Check if this error should trigger rotation
1124
+ if not should_rotate_on_error(classified_error):
1125
+ lib_logger.error(
1126
+ f"Non-recoverable error ({classified_error.error_type}) during custom provider call. Failing."
1127
+ )
1128
+ raise last_exception
1129
+
1130
+ # Handle rate limits with cooldown (exclude quota_exceeded)
1131
+ if classified_error.error_type == "rate_limit":
1132
+ cooldown_duration = classified_error.retry_after or 60
1133
+ await self.cooldown_manager.start_cooldown(
1134
+ provider, cooldown_duration
1135
+ )
1136
+
1137
+ await self.usage_manager.record_failure(
1138
+ current_cred, model, classified_error
1139
+ )
1140
+ lib_logger.warning(
1141
+ f"Cred {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code}). Rotating."
1142
+ )
1143
+ break # Rotate to next credential
1144
+
1145
+ except (
1146
+ APIConnectionError,
1147
+ litellm.InternalServerError,
1148
+ litellm.ServiceUnavailableError,
1149
+ ) as e:
1150
+ last_exception = e
1151
+ log_failure(
1152
+ api_key=current_cred,
1153
+ model=model,
1154
+ attempt=attempt + 1,
1155
+ error=e,
1156
+ request_headers=dict(request.headers)
1157
+ if request
1158
+ else {},
1159
+ )
1160
+ classified_error = classify_error(e, provider=provider)
1161
+ error_message = str(e).split("\n")[0]
1162
+
1163
+ # Provider-level error: don't increment consecutive failures
1164
+ await self.usage_manager.record_failure(
1165
+ current_cred,
1166
+ model,
1167
+ classified_error,
1168
+ increment_consecutive_failures=False,
1169
+ )
1170
+
1171
+ if attempt >= self.max_retries - 1:
1172
+ error_accumulator.record_error(
1173
+ current_cred, classified_error, error_message
1174
+ )
1175
+ lib_logger.warning(
1176
+ f"Cred {mask_credential(current_cred)} failed after max retries. Rotating."
1177
+ )
1178
+ break
1179
+
1180
+ wait_time = classified_error.retry_after or (
1181
+ 2**attempt
1182
+ ) + random.uniform(0, 1)
1183
+ remaining_budget = deadline - time.time()
1184
+ if wait_time > remaining_budget:
1185
+ error_accumulator.record_error(
1186
+ current_cred, classified_error, error_message
1187
+ )
1188
+ lib_logger.warning(
1189
+ f"Retry wait ({wait_time:.2f}s) exceeds budget. Rotating."
1190
+ )
1191
+ break
1192
+
1193
+ lib_logger.warning(
1194
+ f"Cred {mask_credential(current_cred)} server error. Retrying in {wait_time:.2f}s."
1195
+ )
1196
+ await asyncio.sleep(wait_time)
1197
+ continue
1198
+
1199
+ except Exception as e:
1200
+ last_exception = e
1201
+ log_failure(
1202
+ api_key=current_cred,
1203
+ model=model,
1204
+ attempt=attempt + 1,
1205
+ error=e,
1206
+ request_headers=dict(request.headers)
1207
+ if request
1208
+ else {},
1209
+ )
1210
+ classified_error = classify_error(e, provider=provider)
1211
+ error_message = str(e).split("\n")[0]
1212
+
1213
+ # Record in accumulator
1214
+ error_accumulator.record_error(
1215
+ current_cred, classified_error, error_message
1216
+ )
1217
+
1218
+ lib_logger.warning(
1219
+ f"Cred {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code})."
1220
+ )
1221
+
1222
+ # Check if this error should trigger rotation
1223
+ if not should_rotate_on_error(classified_error):
1224
+ lib_logger.error(
1225
+ f"Non-recoverable error ({classified_error.error_type}). Failing."
1226
+ )
1227
+ raise last_exception
1228
+
1229
+ # Handle rate limits with cooldown (exclude quota_exceeded)
1230
+ if (
1231
+ classified_error.status_code == 429
1232
+ and classified_error.error_type != "quota_exceeded"
1233
+ ) or classified_error.error_type == "rate_limit":
1234
+ cooldown_duration = classified_error.retry_after or 60
1235
+ await self.cooldown_manager.start_cooldown(
1236
+ provider, cooldown_duration
1237
+ )
1238
+
1239
+ await self.usage_manager.record_failure(
1240
+ current_cred, model, classified_error
1241
+ )
1242
+ break # Rotate to next credential
1243
+
1244
+ # If the inner loop breaks, it means the key failed and we need to rotate.
1245
+ # Continue to the next iteration of the outer while loop to pick a new key.
1246
+ continue
1247
 
1248
  else: # This is the standard API Key / litellm-handled provider logic
1249
  is_oauth = provider in self.oauth_providers