Mirrowel committed on
Commit
3f958d9
·
1 Parent(s): b5b51f2

Refactor: Streaming response handling and key management.

Browse files

Adjust `chat_completions` in `proxy_app` to correctly differentiate and handle streaming vs. non-streaming responses from the client. This ensures the generator is passed directly for streaming and the awaited result for non-streaming.

In `rotator_library`, remove the immediate `release_key` call after a successful `aembedding` operation, indicating a change in how API keys are managed post-use. Add explicit handling for `asyncio.CancelledError` during retries.

src/proxy_app/main.py CHANGED
@@ -243,16 +243,14 @@ async def chat_completions(
243
  request_data = await request.json()
244
  is_streaming = request_data.get("stream", False)
245
 
246
- response = await client.acompletion(request=request, **request_data)
247
-
248
  if is_streaming:
249
- # For streaming, the response is the generator.
250
  return StreamingResponse(
251
- streaming_response_wrapper(request, request_data, response),
252
  media_type="text/event-stream"
253
  )
254
  else:
255
- # For non-streaming, the response is the completed object.
256
  if ENABLE_REQUEST_LOGGING:
257
  log_request_response(
258
  request_data=request_data,
 
243
  request_data = await request.json()
244
  is_streaming = request_data.get("stream", False)
245
 
 
 
246
  if is_streaming:
247
+ response_generator = client.acompletion(request=request, **request_data)
248
  return StreamingResponse(
249
+ streaming_response_wrapper(request, request_data, response_generator),
250
  media_type="text/event-stream"
251
  )
252
  else:
253
+ response = await client.acompletion(request=request, **request_data)
254
  if ENABLE_REQUEST_LOGGING:
255
  log_request_response(
256
  request_data=request_data,
src/rotator_library/client.py CHANGED
@@ -197,9 +197,6 @@ class RotatingClient:
197
  raise Exception("Failed to complete the request: No available API keys or all keys failed.")
198
 
199
  async def aembedding(self, request: Optional[Any] = None, **kwargs) -> Any:
200
- """
201
- Performs an embedding call with smart key rotation and retry logic.
202
- """
203
  kwargs = self._convert_model_params(**kwargs)
204
  model = kwargs.get("model")
205
  if not model:
@@ -221,10 +218,7 @@ class RotatingClient:
221
  if not keys_to_try:
222
  break
223
 
224
- current_key = await self.usage_manager.acquire_key(
225
- available_keys=keys_to_try,
226
- model=model
227
- )
228
  key_acquired = True
229
  tried_keys.add(current_key)
230
 
@@ -234,45 +228,41 @@ class RotatingClient:
234
  for attempt in range(self.max_retries):
235
  try:
236
  lib_logger.info(f"Attempting embedding call with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")
237
-
238
  response = await litellm.aembedding(api_key=current_key, **litellm_kwargs)
239
 
240
  await self.usage_manager.record_success(current_key, model, response)
241
- await self.usage_manager.release_key(current_key, model)
242
- key_acquired = False
243
  return response
244
 
245
  except Exception as e:
246
  last_exception = e
 
 
247
  log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_data=kwargs)
248
-
249
  classified_error = classify_error(e)
250
 
251
  if classified_error.error_type in ['invalid_request', 'context_window_exceeded']:
252
- lib_logger.error(f"Unrecoverable error '{classified_error.error_type}' with key ...{current_key[-4:]}. Failing request.")
253
  raise last_exception
254
-
255
  if request and await request.is_disconnected():
256
  lib_logger.warning(f"Client disconnected during embedding. Aborting retries for key ...{current_key[-4:]}.")
257
  raise last_exception
258
 
259
  if classified_error.error_type in ['server_error', 'api_connection']:
260
  await self.usage_manager.record_failure(current_key, model, classified_error)
261
-
262
- if attempt >= self.max_retries - 1:
263
- lib_logger.warning(f"Key ...{current_key[-4:]} failed on final retry for {classified_error.error_type}. Trying next key.")
264
- break
265
-
266
- base_wait = 5 if classified_error.error_type == 'api_connection' else 1
267
- wait_time = classified_error.retry_after or (base_wait * (2 ** attempt)) + random.uniform(0, 1)
268
-
269
- lib_logger.warning(f"Key ...{current_key[-4:]} encountered a {classified_error.error_type}. Retrying in {wait_time:.2f} seconds...")
270
  await asyncio.sleep(wait_time)
271
  continue
272
-
273
  await self.usage_manager.record_failure(current_key, model, classified_error)
274
- lib_logger.warning(f"Key ...{current_key[-4:]} encountered '{classified_error.error_type}'. Trying next key.")
275
  break
 
 
 
 
 
 
 
276
  finally:
277
  if key_acquired and current_key:
278
  await self.usage_manager.release_key(current_key, model)
@@ -280,7 +270,7 @@ class RotatingClient:
280
  if last_exception:
281
  raise last_exception
282
 
283
- raise Exception("Failed to complete the request: No available API keys for the provider or all keys failed.")
284
 
285
  def token_count(self, **kwargs) -> int:
286
  """Calculates the number of tokens for a given text or list of messages."""
 
197
  raise Exception("Failed to complete the request: No available API keys or all keys failed.")
198
 
199
  async def aembedding(self, request: Optional[Any] = None, **kwargs) -> Any:
 
 
 
200
  kwargs = self._convert_model_params(**kwargs)
201
  model = kwargs.get("model")
202
  if not model:
 
218
  if not keys_to_try:
219
  break
220
 
221
+ current_key = await self.usage_manager.acquire_key(available_keys=keys_to_try, model=model)
 
 
 
222
  key_acquired = True
223
  tried_keys.add(current_key)
224
 
 
228
  for attempt in range(self.max_retries):
229
  try:
230
  lib_logger.info(f"Attempting embedding call with key ...{current_key[-4:]} (Attempt {attempt + 1}/{self.max_retries})")
 
231
  response = await litellm.aembedding(api_key=current_key, **litellm_kwargs)
232
 
233
  await self.usage_manager.record_success(current_key, model, response)
 
 
234
  return response
235
 
236
  except Exception as e:
237
  last_exception = e
238
+ if isinstance(e, asyncio.CancelledError): raise e
239
+
240
  log_failure(api_key=current_key, model=model, attempt=attempt + 1, error=e, request_data=kwargs)
 
241
  classified_error = classify_error(e)
242
 
243
  if classified_error.error_type in ['invalid_request', 'context_window_exceeded']:
 
244
  raise last_exception
245
+
246
  if request and await request.is_disconnected():
247
  lib_logger.warning(f"Client disconnected during embedding. Aborting retries for key ...{current_key[-4:]}.")
248
  raise last_exception
249
 
250
  if classified_error.error_type in ['server_error', 'api_connection']:
251
  await self.usage_manager.record_failure(current_key, model, classified_error)
252
+ if attempt >= self.max_retries - 1: break
253
+ wait_time = classified_error.retry_after or (1 * (2 ** attempt)) + random.uniform(0, 1)
 
 
 
 
 
 
 
254
  await asyncio.sleep(wait_time)
255
  continue
256
+
257
  await self.usage_manager.record_failure(current_key, model, classified_error)
 
258
  break
259
+
260
+ except (litellm.InvalidRequestError, litellm.ContextWindowExceededError, asyncio.CancelledError) as e:
261
+ raise e
262
+ except Exception as e:
263
+ last_exception = e
264
+ lib_logger.error(f"An unexpected error occurred with key ...{current_key[-4:] if current_key else 'N/A'}: {e}")
265
+ continue
266
  finally:
267
  if key_acquired and current_key:
268
  await self.usage_manager.release_key(current_key, model)
 
270
  if last_exception:
271
  raise last_exception
272
 
273
+ raise Exception("Failed to complete the request: No available API keys or all keys failed.")
274
 
275
  def token_count(self, **kwargs) -> int:
276
  """Calculates the number of tokens for a given text or list of messages."""