Mirrowel committed on
Commit
978e4a8
·
1 Parent(s): 2fdaebd

feat: Abort API retries on client disconnection

Browse files

Pass the FastAPI `request` object to the rotating client's `acompletion` and `aembedding` methods.
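Add a check within the retry loops to detect client disconnections. If the client disconnects, immediately abort further retries to prevent unnecessary resource usage and provide faster feedback. The check runs after a failed attempt, following the fail-fast check for unrecoverable errors: if `request.is_disconnected()` reports a disconnect, a warning is logged and the last exception is re-raised instead of retrying.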

src/proxy_app/main.py CHANGED
@@ -239,7 +239,7 @@ async def chat_completions(
239
  request_data = await request.json()
240
  is_streaming = request_data.get("stream", False)
241
 
242
- response = await client.acompletion(**request_data)
243
 
244
  if is_streaming:
245
  # Wrap the streaming response to enable logging after it's complete
@@ -339,7 +339,7 @@ async def embeddings(
339
  if isinstance(request_data.get("input"), str):
340
  request_data["input"] = [request_data["input"]]
341
 
342
- response = await client.aembedding(**request_data)
343
 
344
  if ENABLE_REQUEST_LOGGING:
345
  response_summary = {
 
239
  request_data = await request.json()
240
  is_streaming = request_data.get("stream", False)
241
 
242
+ response = await client.acompletion(request=request, **request_data)
243
 
244
  if is_streaming:
245
  # Wrap the streaming response to enable logging after it's complete
 
339
  if isinstance(request_data.get("input"), str):
340
  request_data["input"] = [request_data["input"]]
341
 
342
+ response = await client.aembedding(request=request, **request_data)
343
 
344
  if ENABLE_REQUEST_LOGGING:
345
  response_summary = {
src/rotator_library/client.py CHANGED
@@ -176,7 +176,7 @@ class RotatingClient:
176
  yield "data: [DONE]\n\n"
177
 
178
 
179
- async def acompletion(self, pre_request_callback: Optional[callable] = None, **kwargs) -> Union[Any, AsyncGenerator[str, None]]:
180
  """
181
  Performs a completion call with smart key rotation and retry logic.
182
  It will try each available key in sequence if the previous one fails.
@@ -275,11 +275,17 @@ class RotatingClient:
275
 
276
  classified_error = classify_error(e)
277
 
 
 
278
  if classified_error.error_type in ['invalid_request', 'context_window_exceeded']:
279
  # These errors are not recoverable by rotating keys, so fail fast.
280
  lib_logger.error(f"Unrecoverable error '{classified_error.error_type}' with key ...{current_key[-4:]}. Failing request.")
281
  raise last_exception
282
 
 
 
 
 
283
  if classified_error.error_type in ['server_error', 'api_connection']:
284
  # These are temporary, so record the failure and retry with backoff.
285
  await self.usage_manager.record_failure(current_key, model, classified_error)
@@ -310,7 +316,7 @@ class RotatingClient:
310
 
311
  raise Exception("Failed to complete the request: No available API keys for the provider or all keys failed.")
312
 
313
- async def aembedding(self, **kwargs) -> Any:
314
  """
315
  Performs an embedding call with smart key rotation and retry logic.
316
  """
@@ -366,6 +372,10 @@ class RotatingClient:
366
  lib_logger.error(f"Unrecoverable error '{classified_error.error_type}' with key ...{current_key[-4:]}. Failing request.")
367
  raise last_exception
368
 
 
 
 
 
369
  if classified_error.error_type in ['server_error', 'api_connection']:
370
  await self.usage_manager.record_failure(current_key, model, classified_error)
371
 
 
176
  yield "data: [DONE]\n\n"
177
 
178
 
179
+ async def acompletion(self, request: Optional[Any] = None, pre_request_callback: Optional[callable] = None, **kwargs) -> Union[Any, AsyncGenerator[str, None]]:
180
  """
181
  Performs a completion call with smart key rotation and retry logic.
182
  It will try each available key in sequence if the previous one fails.
 
275
 
276
  classified_error = classify_error(e)
277
 
278
+ classified_error = classify_error(e)
279
+
280
  if classified_error.error_type in ['invalid_request', 'context_window_exceeded']:
281
  # These errors are not recoverable by rotating keys, so fail fast.
282
  lib_logger.error(f"Unrecoverable error '{classified_error.error_type}' with key ...{current_key[-4:]}. Failing request.")
283
  raise last_exception
284
 
285
+ if request and await request.is_disconnected():
286
+ lib_logger.warning(f"Client disconnected. Aborting retries for key ...{current_key[-4:]}.")
287
+ raise last_exception
288
+
289
  if classified_error.error_type in ['server_error', 'api_connection']:
290
  # These are temporary, so record the failure and retry with backoff.
291
  await self.usage_manager.record_failure(current_key, model, classified_error)
 
316
 
317
  raise Exception("Failed to complete the request: No available API keys for the provider or all keys failed.")
318
 
319
+ async def aembedding(self, request: Optional[Any] = None, **kwargs) -> Any:
320
  """
321
  Performs an embedding call with smart key rotation and retry logic.
322
  """
 
372
  lib_logger.error(f"Unrecoverable error '{classified_error.error_type}' with key ...{current_key[-4:]}. Failing request.")
373
  raise last_exception
374
 
375
+ if request and await request.is_disconnected():
376
+ lib_logger.warning(f"Client disconnected during embedding. Aborting retries for key ...{current_key[-4:]}.")
377
+ raise last_exception
378
+
379
  if classified_error.error_type in ['server_error', 'api_connection']:
380
  await self.usage_manager.record_failure(current_key, model, classified_error)
381