Mirrowel commited on
Commit
225c46e
·
1 Parent(s): 8a28cd0

feat(providers): ✨ enable dynamic support for custom openai-compatible providers

Browse files

This change introduces a mechanism to automatically detect and register custom OpenAI-compatible providers based on environment variables (e.g., `CUSTOM_PROVIDER_API_BASE`).

- Adds `OpenAICompatibleProvider` class for generic endpoint interactions.
- Updates `AllProviders` (in `error_handler`) to dynamically discover and register these providers by scanning the environment variables.
- Modifies `RotatingClient` to lazily instantiate the compatible provider when a provider name is used and its API base is configured.

This allows users to integrate with any standards-compliant API without explicit internal registration.

src/rotator_library/client.py CHANGED
@@ -12,7 +12,7 @@ from litellm.litellm_core_utils.token_counter import token_counter
12
  import logging
13
  from typing import List, Dict, Any, AsyncGenerator, Optional, Union
14
 
15
- lib_logger = logging.getLogger('rotator_library')
16
  # Ensure the logger is configured to propagate to the root logger
17
  # which is set up in main.py. This allows the main app to control
18
  # log levels and handlers centrally.
@@ -20,24 +20,34 @@ lib_logger.propagate = False
20
 
21
  from .usage_manager import UsageManager
22
  from .failure_logger import log_failure
23
- from .error_handler import PreRequestCallbackError, classify_error, AllProviders, NoAvailableKeysError
 
 
 
 
 
24
  from .providers import PROVIDER_PLUGINS
 
25
  from .request_sanitizer import sanitize_request_payload
26
  from .cooldown_manager import CooldownManager
27
  from .credential_manager import CredentialManager
28
  from .background_refresher import BackgroundRefresher
29
 
 
30
  class StreamedAPIError(Exception):
31
  """Custom exception to signal an API error received over a stream."""
 
32
  def __init__(self, message, data=None):
33
  super().__init__(message)
34
  self.data = data
35
 
 
36
  class RotatingClient:
37
  """
38
  A client that intelligently rotates and retries API keys using LiteLLM,
39
  with support for both streaming and non-streaming responses.
40
  """
 
41
  def __init__(
42
  self,
43
  api_keys: Optional[Dict[str, List[str]]] = None,
@@ -50,7 +60,7 @@ class RotatingClient:
50
  litellm_provider_params: Optional[Dict[str, Any]] = None,
51
  ignore_models: Optional[Dict[str, List[str]]] = None,
52
  whitelist_models: Optional[Dict[str, List[str]]] = None,
53
- enable_request_logging: bool = False
54
  ):
55
  os.environ["LITELLM_LOG"] = "ERROR"
56
  litellm.set_verbose = False
@@ -71,24 +81,28 @@ class RotatingClient:
71
 
72
  # Filter out providers with empty lists of credentials to ensure validity
73
  api_keys = {provider: keys for provider, keys in api_keys.items() if keys}
74
- oauth_credentials = {provider: paths for provider, paths in oauth_credentials.items() if paths}
 
 
75
 
76
  if not api_keys and not oauth_credentials:
77
- raise ValueError("No valid credentials provided. Either 'api_keys' or 'oauth_credentials' must be provided and non-empty.")
 
 
78
 
79
  self.api_keys = api_keys
80
  self.credential_manager = CredentialManager(oauth_credentials)
81
  self.oauth_credentials = self.credential_manager.discover_and_prepare()
82
  self.background_refresher = BackgroundRefresher(self)
83
  self.oauth_providers = set(self.oauth_credentials.keys())
84
-
85
  all_credentials = {}
86
  for provider, keys in api_keys.items():
87
  all_credentials.setdefault(provider, []).extend(keys)
88
  for provider, paths in self.oauth_credentials.items():
89
  all_credentials.setdefault(provider, []).extend(paths)
90
  self.all_credentials = all_credentials
91
-
92
  self.max_retries = max_retries
93
  self.global_timeout = global_timeout
94
  self.abort_on_callback_error = abort_on_callback_error
@@ -109,29 +123,32 @@ class RotatingClient:
109
  Checks if a model should be ignored based on the ignore list.
110
  Supports exact and partial matching for both full model IDs and model names.
111
  """
112
- model_provider = model_id.split('/')[0]
113
  if model_provider not in self.ignore_models:
114
  return False
115
 
116
  ignore_list = self.ignore_models[model_provider]
117
- if ignore_list == ['*']:
118
  return True
119
-
120
  try:
121
  # This is the model name as the provider sees it (e.g., "gpt-4" or "google/gemma-7b")
122
- provider_model_name = model_id.split('/', 1)[1]
123
  except IndexError:
124
  provider_model_name = model_id
125
 
126
  for ignored_pattern in ignore_list:
127
- if ignored_pattern.endswith('*'):
128
  match_pattern = ignored_pattern[:-1]
129
  # Match wildcard against the provider's model name
130
  if provider_model_name.startswith(match_pattern):
131
  return True
132
  else:
133
  # Exact match against the full proxy ID OR the provider's model name
134
- if model_id == ignored_pattern or provider_model_name == ignored_pattern:
 
 
 
135
  return True
136
  return False
137
 
@@ -140,29 +157,32 @@ class RotatingClient:
140
  Checks if a model is explicitly whitelisted.
141
  Supports exact and partial matching for both full model IDs and model names.
142
  """
143
- model_provider = model_id.split('/')[0]
144
  if model_provider not in self.whitelist_models:
145
  return False
146
 
147
  whitelist = self.whitelist_models[model_provider]
148
  for whitelisted_pattern in whitelist:
149
- if whitelisted_pattern == '*':
150
  return True
151
-
152
  try:
153
  # This is the model name as the provider sees it (e.g., "gpt-4" or "google/gemma-7b")
154
- provider_model_name = model_id.split('/', 1)[1]
155
  except IndexError:
156
  provider_model_name = model_id
157
 
158
- if whitelisted_pattern.endswith('*'):
159
  match_pattern = whitelisted_pattern[:-1]
160
  # Match wildcard against the provider's model name
161
  if provider_model_name.startswith(match_pattern):
162
  return True
163
  else:
164
  # Exact match against the full proxy ID OR the provider's model name
165
- if model_id == whitelisted_pattern or provider_model_name == whitelisted_pattern:
 
 
 
166
  return True
167
  return False
168
 
@@ -176,10 +196,16 @@ class RotatingClient:
176
 
177
  # Keys to remove at any level of the dictionary
178
  keys_to_pop = [
179
- "messages", "input", "response", "data", "api_key",
180
- "api_base", "original_response", "additional_args"
 
 
 
 
 
 
181
  ]
182
-
183
  # Keys that might contain nested dictionaries to clean
184
  nested_keys = ["kwargs", "litellm_params", "model_info", "proxy_server_request"]
185
 
@@ -193,12 +219,12 @@ class RotatingClient:
193
  # Remove sensitive/large keys
194
  for key in keys_to_pop:
195
  data_dict.pop(key, None)
196
-
197
  # Recursively clean nested dictionaries
198
  for key in nested_keys:
199
  if key in data_dict and isinstance(data_dict[key], dict):
200
  clean_recursively(data_dict[key])
201
-
202
  # Also iterate through all values to find any other nested dicts
203
  for key, value in list(data_dict.items()):
204
  if isinstance(value, dict):
@@ -217,22 +243,26 @@ class RotatingClient:
217
  log_event_type = log_data.get("log_event_type")
218
  if log_event_type in ["pre_api_call", "post_api_call"]:
219
  return # Skip these verbose logs entirely
220
-
221
  # For successful calls or pre-call logs, a simple debug message is enough.
222
  if not log_data.get("exception"):
223
  sanitized_log = self._sanitize_litellm_log(log_data)
224
  # We log it at the DEBUG level to ensure it goes to the debug file
225
- # and not the console, based on the main.py configuration.
226
  lib_logger.debug(f"LiteLLM Log: {sanitized_log}")
227
  return
228
 
229
  # For failures, extract key info to make debug logs more readable.
230
  model = log_data.get("model", "N/A")
231
  call_id = log_data.get("litellm_call_id", "N/A")
232
- error_info = log_data.get("standard_logging_object", {}).get("error_information", {})
 
 
233
  error_class = error_info.get("error_class", "UnknownError")
234
- error_message = error_info.get("error_message", str(log_data.get("exception", "")))
235
- error_message = ' '.join(error_message.split()) # Sanitize
 
 
236
 
237
  lib_logger.debug(
238
  f"LiteLLM Callback Handled Error: Model={model} | "
@@ -247,7 +277,7 @@ class RotatingClient:
247
 
248
  async def close(self):
249
  """Close the HTTP client to prevent resource leaks."""
250
- if hasattr(self, 'http_client') and self.http_client:
251
  await self.http_client.aclose()
252
 
253
  def _convert_model_params(self, **kwargs) -> Dict[str, Any]:
@@ -260,26 +290,47 @@ class RotatingClient:
260
  if not model:
261
  return kwargs
262
 
263
- provider = model.split('/')[0]
264
  if provider == "chutes":
265
  kwargs["model"] = f"openai/{model.split('/', 1)[1]}"
266
  kwargs["api_base"] = "https://llm.chutes.ai/v1"
267
-
268
  return kwargs
269
 
270
  def get_oauth_credentials(self) -> Dict[str, List[str]]:
271
  return self.oauth_credentials
272
 
 
 
 
 
 
 
 
 
273
  def _get_provider_instance(self, provider_name: str):
274
  """Lazily initializes and returns a provider instance."""
275
  if provider_name not in self._provider_instances:
276
  if provider_name in self._provider_plugins:
277
- self._provider_instances[provider_name] = self._provider_plugins[provider_name]()
 
 
 
 
 
 
 
 
 
 
 
278
  else:
279
  return None
280
  return self._provider_instances[provider_name]
281
 
282
- async def _safe_streaming_wrapper(self, stream: Any, key: str, model: str, request: Optional[Any] = None) -> AsyncGenerator[Any, None]:
 
 
283
  """
284
  A hybrid wrapper for streaming that buffers fragmented JSON, handles client disconnections gracefully,
285
  and distinguishes between content and streamed errors.
@@ -292,7 +343,9 @@ class RotatingClient:
292
  try:
293
  while True:
294
  if request and await request.is_disconnected():
295
- lib_logger.info(f"Client disconnected. Aborting stream for credential ...{key[-6:]}.")
 
 
296
  # Do not yield [DONE] because the client is gone.
297
  # The 'finally' block will handle key release.
298
  break
@@ -302,32 +355,47 @@ class RotatingClient:
302
  if json_buffer:
303
  # If we are about to discard a buffer, it means data was likely lost.
304
  # Log this as a warning to make it visible.
305
- lib_logger.warning(f"Discarding incomplete JSON buffer from previous chunk: {json_buffer}")
 
 
306
  json_buffer = ""
307
-
308
  yield f"data: {json.dumps(chunk.dict())}\n\n"
309
 
310
- if hasattr(chunk, 'usage') and chunk.usage:
311
- last_usage = chunk.usage # Overwrite with the latest (cumulative)
 
 
312
 
313
  except StopAsyncIteration:
314
  stream_completed = True
315
  if json_buffer:
316
- lib_logger.info(f"Stream ended with incomplete data in buffer: {json_buffer}")
 
 
317
  if last_usage:
318
  # Create a dummy ModelResponse for recording (only usage matters)
319
  dummy_response = litellm.ModelResponse(usage=last_usage)
320
- await self.usage_manager.record_success(key, model, dummy_response)
 
 
321
  else:
322
  # If no usage seen (rare), record success without tokens/cost
323
  await self.usage_manager.record_success(key, model)
324
  break
325
 
326
- except (litellm.RateLimitError, litellm.ServiceUnavailableError, litellm.InternalServerError, APIConnectionError) as e:
 
 
 
 
 
327
  # This is a critical, typed error from litellm that signals a key failure.
328
  # We do not try to parse it here. We wrap it and raise it immediately
329
  # for the outer retry loop to handle.
330
- lib_logger.warning(f"Caught a critical API error mid-stream: {type(e).__name__}. Signaling for credential rotation.")
 
 
331
  raise StreamedAPIError("Provider error received in stream", data=e)
332
 
333
  except Exception as e:
@@ -338,13 +406,17 @@ class RotatingClient:
338
  match = re.search(r"b'(\{.*\})'", str(e), re.DOTALL)
339
  if match:
340
  # The extracted string is unicode-escaped (e.g., '\\n'). We must decode it.
341
- raw_chunk = codecs.decode(match.group(1), 'unicode_escape')
342
  else:
343
  # Fallback for other potential error formats that use "Received chunk:".
344
- chunk_from_split = str(e).split("Received chunk:")[-1].strip()
345
- if chunk_from_split != str(e): # Ensure the split actually did something
 
 
 
 
346
  raw_chunk = chunk_from_split
347
-
348
  if not raw_chunk:
349
  # If we could not extract a valid chunk, we cannot proceed with reassembly.
350
  # This indicates a different, unexpected error type. Re-raise it.
@@ -353,26 +425,36 @@ class RotatingClient:
353
  # Append the clean chunk to the buffer and try to parse.
354
  json_buffer += raw_chunk
355
  parsed_data = json.loads(json_buffer)
356
-
357
  # If parsing succeeds, we have the complete object.
358
- lib_logger.info(f"Successfully reassembled JSON from stream: {json_buffer}")
359
-
 
 
360
  # Wrap the complete error object and raise it. The outer function will decide how to handle it.
361
- raise StreamedAPIError("Provider error received in stream", data=parsed_data)
 
 
362
 
363
  except json.JSONDecodeError:
364
  # This is the expected outcome if the JSON in the buffer is not yet complete.
365
- lib_logger.info(f"Buffer still incomplete. Waiting for more chunks: {json_buffer}")
366
- continue # Continue to the next loop to get the next chunk.
 
 
367
  except StreamedAPIError:
368
  # Re-raise to be caught by the outer retry handler.
369
  raise
370
  except Exception as buffer_exc:
371
  # If the error was not a JSONDecodeError, it's an unexpected internal error.
372
- lib_logger.error(f"Error during stream buffering logic: {buffer_exc}. Discarding buffer.")
373
- json_buffer = "" # Clear the corrupted buffer to prevent further issues.
 
 
 
 
374
  raise buffer_exc
375
-
376
  except StreamedAPIError:
377
  # This is caught by the acompletion retry logic.
378
  # We re-raise it to ensure it's not caught by the generic 'except Exception'.
@@ -381,7 +463,9 @@ class RotatingClient:
381
  except Exception as e:
382
  # Catch any other unexpected errors during streaming.
383
  lib_logger.error(f"Caught unexpected exception of type: {type(e).__name__}")
384
- lib_logger.error(f"An unexpected error occurred during the stream for credential ...{key[-6:]}: {e}")
 
 
385
  # We still need to raise it so the client knows something went wrong.
386
  raise
387
 
@@ -389,219 +473,338 @@ class RotatingClient:
389
  # This block now runs regardless of how the stream terminates (completion, client disconnect, etc.).
390
  # The primary goal is to ensure usage is always logged internally.
391
  await self.usage_manager.release_key(key, model)
392
- lib_logger.info(f"STREAM FINISHED and lock released for credential ...{key[-6:]}.")
393
-
 
 
394
  # Only send [DONE] if the stream completed naturally and the client is still there.
395
  # This prevents sending [DONE] to a disconnected client or after an error.
396
- if stream_completed and (not request or not await request.is_disconnected()):
 
 
397
  yield "data: [DONE]\n\n"
398
 
399
- async def _execute_with_retry(self, api_call: callable, request: Optional[Any], pre_request_callback: Optional[callable] = None, **kwargs) -> Any:
 
 
 
 
 
 
400
  """A generic retry mechanism for non-streaming API calls."""
401
  model = kwargs.get("model")
402
  if not model:
403
  raise ValueError("'model' is a required parameter.")
404
 
405
- provider = model.split('/')[0]
406
  if provider not in self.all_credentials:
407
- raise ValueError(f"No API keys or OAuth credentials configured for provider: {provider}")
 
 
408
 
409
  # Establish a global deadline for the entire request lifecycle.
410
  deadline = time.time() + self.global_timeout
411
-
412
  # Create a mutable copy of the keys and shuffle it to ensure
413
  # that the key selection is randomized, which is crucial when
414
  # multiple keys have the same usage stats.
415
  credentials_for_provider = list(self.all_credentials[provider])
416
  random.shuffle(credentials_for_provider)
417
-
418
  tried_creds = set()
419
  last_exception = None
420
  kwargs = self._convert_model_params(**kwargs)
421
 
422
  # The main rotation loop. It continues as long as there are untried credentials and the global deadline has not been exceeded.
423
- while len(tried_creds) < len(credentials_for_provider) and time.time() < deadline:
 
 
424
  current_cred = None
425
  key_acquired = False
426
  try:
427
  # Check for a provider-wide cooldown first.
428
  if await self.cooldown_manager.is_cooling_down(provider):
429
- remaining_cooldown = await self.cooldown_manager.get_cooldown_remaining(provider)
 
 
430
  remaining_budget = deadline - time.time()
431
-
432
  # If the cooldown is longer than the remaining time budget, fail fast.
433
  if remaining_cooldown > remaining_budget:
434
- lib_logger.warning(f"Provider {provider} cooldown ({remaining_cooldown:.2f}s) exceeds remaining request budget ({remaining_budget:.2f}s). Failing early.")
 
 
435
  break
436
 
437
- lib_logger.warning(f"Provider {provider} is in cooldown. Waiting for {remaining_cooldown:.2f} seconds.")
 
 
438
  await asyncio.sleep(remaining_cooldown)
439
 
440
- creds_to_try = [c for c in credentials_for_provider if c not in tried_creds]
 
 
441
  if not creds_to_try:
442
  break
443
 
444
- lib_logger.info(f"Acquiring key for model {model}. Tried keys: {len(tried_creds)}/{len(credentials_for_provider)}")
 
 
445
  current_cred = await self.usage_manager.acquire_key(
446
- available_keys=creds_to_try,
447
- model=model,
448
- deadline=deadline
449
  )
450
  key_acquired = True
451
  tried_creds.add(current_cred)
452
 
453
  litellm_kwargs = self.all_providers.get_provider_kwargs(**kwargs.copy())
454
-
455
  # [NEW] Merge provider-specific params
456
  if provider in self.litellm_provider_params:
457
  litellm_kwargs["litellm_params"] = {
458
  **self.litellm_provider_params[provider],
459
- **litellm_kwargs.get("litellm_params", {})
460
  }
461
-
462
  provider_plugin = self._get_provider_instance(provider)
463
  if provider_plugin and provider_plugin.has_custom_logic():
464
- lib_logger.debug(f"Provider '{provider}' has custom logic. Delegating call.")
 
 
465
  litellm_kwargs["credential_identifier"] = current_cred
466
- litellm_kwargs["enable_request_logging"] = self.enable_request_logging
 
 
467
 
468
  # Check body first for custom_reasoning_budget
469
  if "custom_reasoning_budget" in kwargs:
470
- litellm_kwargs["custom_reasoning_budget"] = kwargs["custom_reasoning_budget"]
 
 
471
  else:
472
  custom_budget_header = None
473
- if request and hasattr(request, 'headers'):
474
- custom_budget_header = request.headers.get("custom_reasoning_budget")
 
 
475
 
476
  if custom_budget_header is not None:
477
- is_budget_enabled = custom_budget_header.lower() == 'true'
478
- litellm_kwargs["custom_reasoning_budget"] = is_budget_enabled
479
-
 
 
480
  # The plugin handles the entire call, including retries on 401, etc.
481
  # The main retry loop here is for key rotation on other errors.
482
- response = await provider_plugin.acompletion(self.http_client, **litellm_kwargs)
483
-
 
 
484
  # For non-streaming, success is immediate, and this function only handles non-streaming.
485
- await self.usage_manager.record_success(current_cred, model, response)
 
 
486
  await self.usage_manager.release_key(current_cred, model)
487
  key_acquired = False
488
  return response
489
 
490
- else: # This is the standard API Key / litellm-handled provider logic
491
  is_oauth = provider in self.oauth_providers
492
- if is_oauth: # Standard OAuth provider (not custom)
493
  # ... (logic to set headers) ...
494
  pass
495
- else: # API Key
496
  litellm_kwargs["api_key"] = current_cred
497
-
498
  provider_instance = self._get_provider_instance(provider)
499
  if provider_instance:
500
  if "safety_settings" in litellm_kwargs:
501
- converted_settings = provider_instance.convert_safety_settings(litellm_kwargs["safety_settings"])
 
 
 
 
502
  if converted_settings is not None:
503
  litellm_kwargs["safety_settings"] = converted_settings
504
  else:
505
  del litellm_kwargs["safety_settings"]
506
-
507
  if provider == "gemini" and provider_instance:
508
- provider_instance.handle_thinking_parameter(litellm_kwargs, model)
 
 
509
  if provider == "nvidia_nim" and provider_instance:
510
- provider_instance.handle_thinking_parameter(litellm_kwargs, model)
 
 
511
 
512
  if "gemma-3" in model and "messages" in litellm_kwargs:
513
- litellm_kwargs["messages"] = [{"role": "user", "content": m["content"]} if m.get("role") == "system" else m for m in litellm_kwargs["messages"]]
514
-
 
 
 
 
 
515
  litellm_kwargs = sanitize_request_payload(litellm_kwargs, model)
516
 
517
  for attempt in range(self.max_retries):
518
  try:
519
- lib_logger.info(f"Attempting call with credential ...{current_cred[-6:]} (Attempt {attempt + 1}/{self.max_retries})")
520
-
 
 
521
  if pre_request_callback:
522
  try:
523
  await pre_request_callback(request, litellm_kwargs)
524
  except Exception as e:
525
  if self.abort_on_callback_error:
526
- raise PreRequestCallbackError(f"Pre-request callback failed: {e}") from e
 
 
527
  else:
528
- lib_logger.warning(f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}")
529
-
 
 
530
  response = await api_call(
531
  **litellm_kwargs,
532
- logger_fn=self._litellm_logger_callback
 
 
 
 
533
  )
534
-
535
- await self.usage_manager.record_success(current_cred, model, response)
536
  await self.usage_manager.release_key(current_cred, model)
537
  key_acquired = False
538
  return response
539
 
540
  except litellm.RateLimitError as e:
541
  last_exception = e
542
- log_failure(api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {})
 
 
 
 
 
 
 
 
543
  classified_error = classify_error(e)
544
-
545
  # Extract a clean error message for the user-facing log
546
- error_message = str(e).split('\n')[0]
547
- lib_logger.info(f"Key ...{current_cred[-6:]} hit rate limit for model {model}. Reason: '{error_message}'. Rotating key.")
 
 
548
 
549
  if classified_error.status_code == 429:
550
  cooldown_duration = classified_error.retry_after or 60
551
- await self.cooldown_manager.start_cooldown(provider, cooldown_duration)
552
- lib_logger.warning(f"IP-based rate limit detected for {provider}. Starting a {cooldown_duration}-second global cooldown.")
553
-
554
- await self.usage_manager.record_failure(current_cred, model, classified_error)
555
- lib_logger.warning(f"Key ...{current_cred[-6:]} encountered a rate limit. Trying next key.")
556
- break # Move to the next key
 
 
 
 
 
 
 
 
557
 
558
- except (APIConnectionError, litellm.InternalServerError, litellm.ServiceUnavailableError) as e:
 
 
 
 
559
  last_exception = e
560
- log_failure(api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {})
 
 
 
 
 
 
 
 
561
  classified_error = classify_error(e)
562
- await self.usage_manager.record_failure(current_cred, model, classified_error)
 
 
563
 
564
  if attempt >= self.max_retries - 1:
565
- error_message = str(e).split('\n')[0]
566
- lib_logger.warning(f"Key ...{current_cred[-6:]} failed after max retries for model {model} due to a server error. Reason: '{error_message}'. Rotating key.")
567
- break # Move to the next key
568
-
 
 
569
  # For temporary errors, wait before retrying with the same key.
570
- wait_time = classified_error.retry_after or (1 * (2 ** attempt)) + random.uniform(0, 1)
 
 
571
  remaining_budget = deadline - time.time()
572
-
573
  # If the required wait time exceeds the budget, don't wait; rotate to the next key immediately.
574
  if wait_time > remaining_budget:
575
- lib_logger.warning(f"Required retry wait time ({wait_time:.2f}s) exceeds remaining budget ({remaining_budget:.2f}s). Rotating key early.")
 
 
576
  break
577
 
578
- error_message = str(e).split('\n')[0]
579
- lib_logger.warning(f"Key ...{current_cred[-6:]} encountered a server error for model {model}. Reason: '{error_message}'. Retrying in {wait_time:.2f}s.")
 
 
580
  await asyncio.sleep(wait_time)
581
- continue # Retry with the same key
582
 
583
  except Exception as e:
584
  last_exception = e
585
- log_failure(api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {})
586
-
 
 
 
 
 
 
 
 
587
  if request and await request.is_disconnected():
588
- lib_logger.warning(f"Client disconnected. Aborting retries for credential ...{current_cred[-6:]}.")
 
 
589
  raise last_exception
590
 
591
  classified_error = classify_error(e)
592
- error_message = str(e).split('\n')[0]
593
- lib_logger.warning(f"Key ...{current_cred[-6:]} failed with {classified_error.error_type} (Status: {classified_error.status_code}). Error: {error_message}. Rotating key.")
 
 
594
  if classified_error.status_code == 429:
595
  cooldown_duration = classified_error.retry_after or 60
596
- await self.cooldown_manager.start_cooldown(provider, cooldown_duration)
597
- lib_logger.warning(f"IP-based rate limit detected for {provider} from generic exception. Starting a {cooldown_duration}-second global cooldown.")
598
-
599
- if classified_error.error_type in ['invalid_request', 'context_window_exceeded', 'authentication']:
 
 
 
 
 
 
 
 
600
  # For these errors, we should not retry with other keys.
601
  raise last_exception
602
 
603
- await self.usage_manager.record_failure(current_cred, model, classified_error)
604
- break # Try next key for other errors
 
 
605
  finally:
606
  if key_acquired and current_cred:
607
  await self.usage_manager.release_key(current_cred, model)
@@ -609,269 +812,411 @@ class RotatingClient:
609
  if last_exception:
610
  # Log the final error but do not raise it, as per the new requirement.
611
  # The client should not see intermittent failures.
612
- lib_logger.error(f"Request failed after trying all keys or exceeding global timeout. Last error: {last_exception}")
613
-
 
 
614
  # Return None to indicate failure without propagating a disruptive exception.
615
  return None
616
 
617
- async def _streaming_acompletion_with_retry(self, request: Optional[Any], pre_request_callback: Optional[callable] = None, **kwargs) -> AsyncGenerator[str, None]:
 
 
 
 
 
618
  """A dedicated generator for retrying streaming completions with full request preparation and per-key retries."""
619
  model = kwargs.get("model")
620
- provider = model.split('/')[0]
621
-
622
  # Create a mutable copy of the keys and shuffle it.
623
  credentials_for_provider = list(self.all_credentials[provider])
624
  random.shuffle(credentials_for_provider)
625
-
626
  deadline = time.time() + self.global_timeout
627
  tried_creds = set()
628
  last_exception = None
629
  kwargs = self._convert_model_params(**kwargs)
630
-
631
  consecutive_quota_failures = 0
632
 
633
  try:
634
- while len(tried_creds) < len(credentials_for_provider) and time.time() < deadline:
 
 
 
635
  current_cred = None
636
  key_acquired = False
637
  try:
638
  if await self.cooldown_manager.is_cooling_down(provider):
639
- remaining_cooldown = await self.cooldown_manager.get_cooldown_remaining(provider)
 
 
640
  remaining_budget = deadline - time.time()
641
  if remaining_cooldown > remaining_budget:
642
- lib_logger.warning(f"Provider {provider} cooldown ({remaining_cooldown:.2f}s) exceeds remaining request budget ({remaining_budget:.2f}s). Failing early.")
 
 
643
  break
644
- lib_logger.warning(f"Provider {provider} is in a global cooldown. All requests to this provider will be paused for {remaining_cooldown:.2f} seconds.")
 
 
645
  await asyncio.sleep(remaining_cooldown)
646
 
647
- creds_to_try = [c for c in credentials_for_provider if c not in tried_creds]
 
 
648
  if not creds_to_try:
649
- lib_logger.warning(f"All credentials for provider {provider} have been tried. No more credentials to rotate to.")
 
 
650
  break
651
 
652
- lib_logger.info(f"Acquiring credential for model {model}. Tried credentials: {len(tried_creds)}/{len(credentials_for_provider)}")
 
 
653
  current_cred = await self.usage_manager.acquire_key(
654
- available_keys=creds_to_try,
655
- model=model,
656
- deadline=deadline
657
  )
658
  key_acquired = True
659
  tried_creds.add(current_cred)
660
 
661
- litellm_kwargs = self.all_providers.get_provider_kwargs(**kwargs.copy())
 
 
662
  if "reasoning_effort" in kwargs:
663
  litellm_kwargs["reasoning_effort"] = kwargs["reasoning_effort"]
664
  # Check body first for custom_reasoning_budget
665
  if "custom_reasoning_budget" in kwargs:
666
- litellm_kwargs["custom_reasoning_budget"] = kwargs["custom_reasoning_budget"]
 
 
667
  else:
668
  custom_budget_header = None
669
- if request and hasattr(request, 'headers'):
670
- custom_budget_header = request.headers.get("custom_reasoning_budget")
 
 
671
 
672
  if custom_budget_header is not None:
673
- is_budget_enabled = custom_budget_header.lower() == 'true'
674
- litellm_kwargs["custom_reasoning_budget"] = is_budget_enabled
675
-
 
 
676
  # [NEW] Merge provider-specific params
677
  if provider in self.litellm_provider_params:
678
  litellm_kwargs["litellm_params"] = {
679
  **self.litellm_provider_params[provider],
680
- **litellm_kwargs.get("litellm_params", {})
681
  }
682
 
683
  provider_plugin = self._get_provider_instance(provider)
684
  if provider_plugin and provider_plugin.has_custom_logic():
685
- lib_logger.debug(f"Provider '{provider}' has custom logic. Delegating call.")
 
 
686
  litellm_kwargs["credential_identifier"] = current_cred
687
- litellm_kwargs["enable_request_logging"] = self.enable_request_logging
688
-
 
 
689
  for attempt in range(self.max_retries):
690
  try:
691
- lib_logger.info(f"Attempting stream with credential ...{current_cred[-6:]} (Attempt {attempt + 1}/{self.max_retries})")
692
-
 
 
693
  if pre_request_callback:
694
  try:
695
- await pre_request_callback(request, litellm_kwargs)
 
 
696
  except Exception as e:
697
  if self.abort_on_callback_error:
698
- raise PreRequestCallbackError(f"Pre-request callback failed: {e}") from e
 
 
699
  else:
700
- lib_logger.warning(f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}")
701
-
702
- response = await provider_plugin.acompletion(self.http_client, **litellm_kwargs)
703
-
704
- lib_logger.info(f"Stream connection established for credential ...{current_cred[-6:]}. Processing response.")
 
 
 
 
 
 
705
 
706
  key_acquired = False
707
- stream_generator = self._safe_streaming_wrapper(response, current_cred, model, request)
708
-
 
 
709
  async for chunk in stream_generator:
710
  yield chunk
711
  return
712
 
713
- except (StreamedAPIError, litellm.RateLimitError, httpx.HTTPStatusError) as e:
714
- if isinstance(e, httpx.HTTPStatusError) and e.response.status_code != 429:
 
 
 
 
 
 
 
715
  raise e
716
 
717
  last_exception = e
718
  # If the exception is our custom wrapper, unwrap the original error
719
- original_exc = getattr(e, 'data', e)
720
  classified_error = classify_error(original_exc)
721
- await self.usage_manager.record_failure(current_cred, model, classified_error)
722
- lib_logger.warning(f"Credential ...{current_cred[-6:]} encountered a recoverable error ({classified_error.error_type}) during custom provider stream. Rotating key.")
 
 
 
 
723
  break
724
 
725
- except (APIConnectionError, litellm.InternalServerError, litellm.ServiceUnavailableError) as e:
 
 
 
 
726
  last_exception = e
727
- log_failure(api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {})
 
 
 
 
 
 
 
 
728
  classified_error = classify_error(e)
729
- await self.usage_manager.record_failure(current_cred, model, classified_error)
 
 
730
 
731
  if attempt >= self.max_retries - 1:
732
- lib_logger.warning(f"Credential ...{current_cred[-6:]} failed after max retries for model {model} due to a server error. Rotating key.")
 
 
733
  break
734
-
735
- wait_time = classified_error.retry_after or (1 * (2 ** attempt)) + random.uniform(0, 1)
 
 
736
  remaining_budget = deadline - time.time()
737
  if wait_time > remaining_budget:
738
- lib_logger.warning(f"Required retry wait time ({wait_time:.2f}s) exceeds remaining budget ({remaining_budget:.2f}s). Rotating key early.")
 
 
739
  break
740
-
741
- error_message = str(e).split('\n')[0]
742
- lib_logger.warning(f"Credential ...{current_cred[-6:]} encountered a server error for model {model}. Reason: '{error_message}'. Retrying in {wait_time:.2f}s.")
 
 
743
  await asyncio.sleep(wait_time)
744
  continue
745
 
746
  except Exception as e:
747
  last_exception = e
748
- log_failure(api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {})
 
 
 
 
 
 
 
 
749
  classified_error = classify_error(e)
750
- lib_logger.warning(f"Credential ...{current_cred[-6:]} failed with {classified_error.error_type} (Status: {classified_error.status_code}). Error: {str(e)}. Rotating key.")
751
- if classified_error.error_type in ['invalid_request', 'context_window_exceeded', 'authentication']:
 
 
 
 
 
 
752
  raise last_exception
753
- await self.usage_manager.record_failure(current_cred, model, classified_error)
 
 
754
  break
755
-
756
  # If the inner loop breaks, it means the key failed and we need to rotate.
757
  # Continue to the next iteration of the outer while loop to pick a new key.
758
  continue
759
 
760
- else: # This is the standard API Key / litellm-handled provider logic
761
  is_oauth = provider in self.oauth_providers
762
- if is_oauth: # Standard OAuth provider (not custom)
763
  # ... (logic to set headers) ...
764
  pass
765
- else: # API Key
766
  litellm_kwargs["api_key"] = current_cred
767
 
768
  provider_instance = self._get_provider_instance(provider)
769
  if provider_instance:
770
  if "safety_settings" in litellm_kwargs:
771
- converted_settings = provider_instance.convert_safety_settings(litellm_kwargs["safety_settings"])
 
 
 
 
772
  if converted_settings is not None:
773
  litellm_kwargs["safety_settings"] = converted_settings
774
  else:
775
  del litellm_kwargs["safety_settings"]
776
-
777
  if provider == "gemini" and provider_instance:
778
- provider_instance.handle_thinking_parameter(litellm_kwargs, model)
 
 
779
  if provider == "nvidia_nim" and provider_instance:
780
- provider_instance.handle_thinking_parameter(litellm_kwargs, model)
 
 
781
 
782
  if "gemma-3" in model and "messages" in litellm_kwargs:
783
- litellm_kwargs["messages"] = [{"role": "user", "content": m["content"]} if m.get("role") == "system" else m for m in litellm_kwargs["messages"]]
784
-
 
 
 
 
 
785
  litellm_kwargs = sanitize_request_payload(litellm_kwargs, model)
786
 
787
  # If the provider is 'qwen_code', set the custom provider to 'qwen'
788
  # and strip the prefix from the model name for LiteLLM.
789
  if provider == "qwen_code":
790
  litellm_kwargs["custom_llm_provider"] = "qwen"
791
- litellm_kwargs["model"] = model.split('/', 1)[1]
792
 
793
  for attempt in range(self.max_retries):
794
  try:
795
- lib_logger.info(f"Attempting stream with credential ...{current_cred[-6:]} (Attempt {attempt + 1}/{self.max_retries})")
796
-
 
 
797
  if pre_request_callback:
798
  try:
799
  await pre_request_callback(request, litellm_kwargs)
800
  except Exception as e:
801
  if self.abort_on_callback_error:
802
- raise PreRequestCallbackError(f"Pre-request callback failed: {e}") from e
 
 
803
  else:
804
- lib_logger.warning(f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}")
805
-
806
- #lib_logger.info(f"DEBUG: litellm.acompletion kwargs: {litellm_kwargs}")
 
 
807
  response = await litellm.acompletion(
808
  **litellm_kwargs,
809
- logger_fn=self._litellm_logger_callback
 
 
 
 
810
  )
811
-
812
- lib_logger.info(f"Stream connection established for credential ...{current_cred[-6:]}. Processing response.")
813
 
814
  key_acquired = False
815
- stream_generator = self._safe_streaming_wrapper(response, current_cred, model, request)
816
-
 
 
817
  async for chunk in stream_generator:
818
  yield chunk
819
  return
820
 
821
  except (StreamedAPIError, litellm.RateLimitError) as e:
822
  last_exception = e
823
-
824
  # This is the final, robust handler for streamed errors.
825
  error_payload = {}
826
  cleaned_str = None
827
  # The actual exception might be wrapped in our StreamedAPIError.
828
- original_exc = getattr(e, 'data', e)
829
  classified_error = classify_error(original_exc)
830
 
831
  try:
832
  # The full error JSON is in the string representation of the exception.
833
- json_str_match = re.search(r'(\{.*\})', str(original_exc), re.DOTALL)
 
 
834
  if json_str_match:
835
  # The string may contain byte-escaped characters (e.g., \\n).
836
- cleaned_str = codecs.decode(json_str_match.group(1), 'unicode_escape')
 
 
837
  error_payload = json.loads(cleaned_str)
838
  except (json.JSONDecodeError, TypeError):
839
- lib_logger.warning("Could not parse JSON details from streamed error exception.")
 
 
840
  error_payload = {}
841
-
842
  # Now, log the failure with the extracted raw response.
843
  log_failure(
844
  api_key=current_cred,
845
  model=model,
846
  attempt=attempt + 1,
847
  error=e,
848
- request_headers=dict(request.headers) if request else {},
849
- raw_response_text=cleaned_str
 
 
850
  )
851
 
852
  error_details = error_payload.get("error", {})
853
  error_status = error_details.get("status", "")
854
  # Fallback to the full string if parsing fails.
855
- error_message_text = error_details.get("message", str(original_exc))
 
 
856
 
857
- if "quota" in error_message_text.lower() or "resource_exhausted" in error_status.lower():
 
 
 
858
  consecutive_quota_failures += 1
859
- lib_logger.warning(f"Credential ...{current_cred[-6:]} hit a quota limit. This is consecutive failure #{consecutive_quota_failures} for this request.")
 
 
860
 
861
  quota_value = "N/A"
862
  quota_id = "N/A"
863
- if "details" in error_details and isinstance(error_details.get("details"), list):
 
 
864
  for detail in error_details["details"]:
865
  if isinstance(detail.get("violations"), list):
866
  for violation in detail["violations"]:
867
  if "quotaValue" in violation:
868
- quota_value = violation["quotaValue"]
 
 
869
  if "quotaId" in violation:
870
  quota_id = violation["quotaId"]
871
- if quota_value != "N/A" and quota_id != "N/A":
 
 
 
872
  break
873
-
874
- await self.usage_manager.record_failure(current_cred, model, classified_error)
 
 
875
 
876
  if consecutive_quota_failures >= 3:
877
  console_log_message = (
@@ -884,100 +1229,176 @@ class RotatingClient:
884
  f"Last Error Message: '{error_message_text}'. Limit: {quota_value} (Quota ID: {quota_id})."
885
  )
886
  lib_logger.error(console_log_message)
887
-
888
  yield f"data: {json.dumps({'error': {'message': client_error_message, 'type': 'proxy_fatal_quota_error'}})}\n\n"
889
  yield "data: [DONE]\n\n"
890
  return
891
-
892
  else:
893
  # [MODIFIED] Do not yield to the client. Just log and break to rotate the key.
894
- lib_logger.warning(f"Quota error on credential ...{current_cred[-6:]} (failure {consecutive_quota_failures}/3). Rotating key silently.")
 
 
895
  break
896
-
897
  else:
898
  consecutive_quota_failures = 0
899
  # [MODIFIED] Do not yield to the client. Just log and break to rotate the key.
900
- lib_logger.warning(f"Credential ...{current_cred[-6:]} encountered a recoverable error ({classified_error.error_type}) during stream. Rotating key silently.")
901
-
902
- if classified_error.error_type == 'rate_limit' and classified_error.status_code == 429:
903
- cooldown_duration = classified_error.retry_after or 60
904
- await self.cooldown_manager.start_cooldown(provider, cooldown_duration)
905
- lib_logger.warning(f"IP-based rate limit detected for {provider}. Starting a {cooldown_duration}-second global cooldown.")
906
-
907
- await self.usage_manager.record_failure(current_cred, model, classified_error)
 
 
 
 
 
 
 
 
 
 
 
 
 
908
  break
909
 
910
- except (APIConnectionError, litellm.InternalServerError, litellm.ServiceUnavailableError) as e:
 
 
 
 
911
  consecutive_quota_failures = 0
912
  last_exception = e
913
- log_failure(api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {})
 
 
 
 
 
 
 
 
914
  classified_error = classify_error(e)
915
- await self.usage_manager.record_failure(current_cred, model, classified_error)
 
 
916
 
917
  if attempt >= self.max_retries - 1:
918
- lib_logger.warning(f"Credential ...{current_cred[-6:]} failed after max retries for model {model} due to a server error. Rotating key silently.")
 
 
919
  # [MODIFIED] Do not yield to the client here.
920
  break
921
-
922
- wait_time = classified_error.retry_after or (1 * (2 ** attempt)) + random.uniform(0, 1)
 
 
923
  remaining_budget = deadline - time.time()
924
  if wait_time > remaining_budget:
925
- lib_logger.warning(f"Required retry wait time ({wait_time:.2f}s) exceeds remaining budget ({remaining_budget:.2f}s). Rotating key early.")
 
 
926
  break
927
-
928
- error_message = str(e).split('\n')[0]
929
- lib_logger.warning(f"Credential ...{current_cred[-6:]} encountered a server error for model {model}. Reason: '{error_message}'. Retrying in {wait_time:.2f}s.")
 
 
930
  await asyncio.sleep(wait_time)
931
  continue
932
 
933
  except Exception as e:
934
  consecutive_quota_failures = 0
935
  last_exception = e
936
- log_failure(api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {})
 
 
 
 
 
 
 
 
937
  classified_error = classify_error(e)
938
 
939
- lib_logger.warning(f"Credential ...{current_cred[-6:]} failed with {classified_error.error_type} (Status: {classified_error.status_code}). Error: {str(e)}. Rotating key.")
 
 
940
 
941
  if classified_error.status_code == 429:
942
  cooldown_duration = classified_error.retry_after or 60
943
- await self.cooldown_manager.start_cooldown(provider, cooldown_duration)
944
- lib_logger.warning(f"IP-based rate limit detected for {provider} from generic stream exception. Starting a {cooldown_duration}-second global cooldown.")
945
-
946
- if classified_error.error_type in ['invalid_request', 'context_window_exceeded', 'authentication']:
 
 
 
 
 
 
 
 
947
  raise last_exception
948
-
949
  # [MODIFIED] Do not yield to the client here.
950
- await self.usage_manager.record_failure(current_cred, model, classified_error)
 
 
951
  break
952
 
953
  finally:
954
  if key_acquired and current_cred:
955
  await self.usage_manager.release_key(current_cred, model)
956
-
957
  final_error_message = "Failed to complete the streaming request: No available API keys after rotation or global timeout exceeded."
958
  if last_exception:
959
  final_error_message = f"Failed to complete the streaming request. Last error: {str(last_exception)}"
960
- lib_logger.error(f"Streaming request failed after trying all keys. Last error: {last_exception}")
 
 
961
  else:
962
  lib_logger.error(final_error_message)
963
 
964
- error_data = {"error": {"message": final_error_message, "type": "proxy_error"}}
 
 
965
  yield f"data: {json.dumps(error_data)}\n\n"
966
  yield "data: [DONE]\n\n"
967
 
968
  except NoAvailableKeysError as e:
969
- lib_logger.error(f"A streaming request failed because no keys were available within the time budget: {e}")
 
 
970
  error_data = {"error": {"message": str(e), "type": "proxy_busy"}}
971
  yield f"data: {json.dumps(error_data)}\n\n"
972
  yield "data: [DONE]\n\n"
973
  except Exception as e:
974
  # This will now only catch fatal errors that should be raised, like invalid requests.
975
- lib_logger.error(f"An unhandled exception occurred in streaming retry logic: {e}", exc_info=True)
976
- error_data = {"error": {"message": f"An unexpected error occurred: {str(e)}", "type": "proxy_internal_error"}}
 
 
 
 
 
 
 
 
977
  yield f"data: {json.dumps(error_data)}\n\n"
978
  yield "data: [DONE]\n\n"
979
 
980
- def acompletion(self, request: Optional[Any] = None, pre_request_callback: Optional[callable] = None, **kwargs) -> Union[Any, AsyncGenerator[str, None]]:
 
 
 
 
 
981
  """
982
  Dispatcher for completion requests.
983
 
@@ -996,11 +1417,23 @@ class RotatingClient:
996
  kwargs["stream_options"] = {}
997
  if "include_usage" not in kwargs["stream_options"]:
998
  kwargs["stream_options"]["include_usage"] = True
999
- return self._streaming_acompletion_with_retry(request=request, pre_request_callback=pre_request_callback, **kwargs)
 
 
1000
  else:
1001
- return self._execute_with_retry(litellm.acompletion, request=request, pre_request_callback=pre_request_callback, **kwargs)
1002
-
1003
- def aembedding(self, request: Optional[Any] = None, pre_request_callback: Optional[callable] = None, **kwargs) -> Any:
 
 
 
 
 
 
 
 
 
 
1004
  """
1005
  Executes an embedding request with retry logic.
1006
 
@@ -1014,7 +1447,12 @@ class RotatingClient:
1014
  Returns:
1015
  The embedding response object, or None if all retries fail.
1016
  """
1017
- return self._execute_with_retry(litellm.aembedding, request=request, pre_request_callback=pre_request_callback, **kwargs)
 
 
 
 
 
1018
 
1019
  def token_count(self, **kwargs) -> int:
1020
  """Calculates the number of tokens for a given text or list of messages."""
@@ -1057,10 +1495,20 @@ class RotatingClient:
1057
  for credential in shuffled_credentials:
1058
  try:
1059
  # Display last 6 chars for API keys, or the filename for OAuth paths
1060
- cred_display = credential[-6:] if not os.path.isfile(credential) else os.path.basename(credential)
1061
- lib_logger.debug(f"Attempting to get models for {provider} with credential ...{cred_display}")
1062
- models = await provider_instance.get_models(credential, self.http_client)
1063
- lib_logger.info(f"Got {len(models)} models for provider: {provider}")
 
 
 
 
 
 
 
 
 
 
1064
 
1065
  # Whitelist and blacklist logic
1066
  final_models = []
@@ -1076,23 +1524,35 @@ class RotatingClient:
1076
  final_models.append(m)
1077
 
1078
  if len(final_models) != len(models):
1079
- lib_logger.info(f"Filtered out {len(models) - len(final_models)} models for provider {provider}.")
 
 
1080
 
1081
  self._model_list_cache[provider] = final_models
1082
  return final_models
1083
  except Exception as e:
1084
  classified_error = classify_error(e)
1085
- cred_display = credential[-6:] if not os.path.isfile(credential) else os.path.basename(credential)
1086
- lib_logger.debug(f"Failed to get models for provider {provider} with credential ...{cred_display}: {classified_error.error_type}. Trying next credential.")
1087
- continue # Try the next credential
 
 
 
 
 
 
1088
 
1089
- lib_logger.error(f"Failed to get models for provider {provider} after trying all credentials.")
 
 
1090
  return []
1091
 
1092
- async def get_all_available_models(self, grouped: bool = True) -> Union[Dict[str, List[str]], List[str]]:
 
 
1093
  """Returns a list of all available models, either grouped by provider or as a flat list."""
1094
  lib_logger.info("Getting all available models...")
1095
-
1096
  all_providers = list(self.all_credentials.keys())
1097
  tasks = [self.get_available_models(provider) for provider in all_providers]
1098
  results = await asyncio.gather(*tasks, return_exceptions=True)
@@ -1100,11 +1560,13 @@ class RotatingClient:
1100
  all_provider_models = {}
1101
  for provider, result in zip(all_providers, results):
1102
  if isinstance(result, Exception):
1103
- lib_logger.error(f"Failed to get models for provider {provider}: {result}")
 
 
1104
  all_provider_models[provider] = []
1105
  else:
1106
  all_provider_models[provider] = result
1107
-
1108
  lib_logger.info("Finished getting all available models.")
1109
  if grouped:
1110
  return all_provider_models
 
12
  import logging
13
  from typing import List, Dict, Any, AsyncGenerator, Optional, Union
14
 
15
+ lib_logger = logging.getLogger("rotator_library")
16
  # Ensure the logger is configured to propagate to the root logger
17
  # which is set up in main.py. This allows the main app to control
18
  # log levels and handlers centrally.
 
20
 
21
  from .usage_manager import UsageManager
22
  from .failure_logger import log_failure
23
+ from .error_handler import (
24
+ PreRequestCallbackError,
25
+ classify_error,
26
+ AllProviders,
27
+ NoAvailableKeysError,
28
+ )
29
  from .providers import PROVIDER_PLUGINS
30
+ from .providers.openai_compatible_provider import OpenAICompatibleProvider
31
  from .request_sanitizer import sanitize_request_payload
32
  from .cooldown_manager import CooldownManager
33
  from .credential_manager import CredentialManager
34
  from .background_refresher import BackgroundRefresher
35
 
36
+
37
class StreamedAPIError(Exception):
    """Custom exception to signal an API error received over a stream."""

    def __init__(self, message, data=None):
        # Keep the original exception/payload around so the retry loops can
        # unwrap it later (via `getattr(e, 'data', e)`).
        self.data = data
        super().__init__(message)
43
 
44
+
45
  class RotatingClient:
46
  """
47
  A client that intelligently rotates and retries API keys using LiteLLM,
48
  with support for both streaming and non-streaming responses.
49
  """
50
+
51
  def __init__(
52
  self,
53
  api_keys: Optional[Dict[str, List[str]]] = None,
 
60
  litellm_provider_params: Optional[Dict[str, Any]] = None,
61
  ignore_models: Optional[Dict[str, List[str]]] = None,
62
  whitelist_models: Optional[Dict[str, List[str]]] = None,
63
+ enable_request_logging: bool = False,
64
  ):
65
  os.environ["LITELLM_LOG"] = "ERROR"
66
  litellm.set_verbose = False
 
81
 
82
  # Filter out providers with empty lists of credentials to ensure validity
83
  api_keys = {provider: keys for provider, keys in api_keys.items() if keys}
84
+ oauth_credentials = {
85
+ provider: paths for provider, paths in oauth_credentials.items() if paths
86
+ }
87
 
88
  if not api_keys and not oauth_credentials:
89
+ raise ValueError(
90
+ "No valid credentials provided. Either 'api_keys' or 'oauth_credentials' must be provided and non-empty."
91
+ )
92
 
93
  self.api_keys = api_keys
94
  self.credential_manager = CredentialManager(oauth_credentials)
95
  self.oauth_credentials = self.credential_manager.discover_and_prepare()
96
  self.background_refresher = BackgroundRefresher(self)
97
  self.oauth_providers = set(self.oauth_credentials.keys())
98
+
99
  all_credentials = {}
100
  for provider, keys in api_keys.items():
101
  all_credentials.setdefault(provider, []).extend(keys)
102
  for provider, paths in self.oauth_credentials.items():
103
  all_credentials.setdefault(provider, []).extend(paths)
104
  self.all_credentials = all_credentials
105
+
106
  self.max_retries = max_retries
107
  self.global_timeout = global_timeout
108
  self.abort_on_callback_error = abort_on_callback_error
 
123
  Checks if a model should be ignored based on the ignore list.
124
  Supports exact and partial matching for both full model IDs and model names.
125
  """
126
+ model_provider = model_id.split("/")[0]
127
  if model_provider not in self.ignore_models:
128
  return False
129
 
130
  ignore_list = self.ignore_models[model_provider]
131
+ if ignore_list == ["*"]:
132
  return True
133
+
134
  try:
135
  # This is the model name as the provider sees it (e.g., "gpt-4" or "google/gemma-7b")
136
+ provider_model_name = model_id.split("/", 1)[1]
137
  except IndexError:
138
  provider_model_name = model_id
139
 
140
  for ignored_pattern in ignore_list:
141
+ if ignored_pattern.endswith("*"):
142
  match_pattern = ignored_pattern[:-1]
143
  # Match wildcard against the provider's model name
144
  if provider_model_name.startswith(match_pattern):
145
  return True
146
  else:
147
  # Exact match against the full proxy ID OR the provider's model name
148
+ if (
149
+ model_id == ignored_pattern
150
+ or provider_model_name == ignored_pattern
151
+ ):
152
  return True
153
  return False
154
 
 
157
  Checks if a model is explicitly whitelisted.
158
  Supports exact and partial matching for both full model IDs and model names.
159
  """
160
+ model_provider = model_id.split("/")[0]
161
  if model_provider not in self.whitelist_models:
162
  return False
163
 
164
  whitelist = self.whitelist_models[model_provider]
165
  for whitelisted_pattern in whitelist:
166
+ if whitelisted_pattern == "*":
167
  return True
168
+
169
  try:
170
  # This is the model name as the provider sees it (e.g., "gpt-4" or "google/gemma-7b")
171
+ provider_model_name = model_id.split("/", 1)[1]
172
  except IndexError:
173
  provider_model_name = model_id
174
 
175
+ if whitelisted_pattern.endswith("*"):
176
  match_pattern = whitelisted_pattern[:-1]
177
  # Match wildcard against the provider's model name
178
  if provider_model_name.startswith(match_pattern):
179
  return True
180
  else:
181
  # Exact match against the full proxy ID OR the provider's model name
182
+ if (
183
+ model_id == whitelisted_pattern
184
+ or provider_model_name == whitelisted_pattern
185
+ ):
186
  return True
187
  return False
188
 
 
196
 
197
  # Keys to remove at any level of the dictionary
198
  keys_to_pop = [
199
+ "messages",
200
+ "input",
201
+ "response",
202
+ "data",
203
+ "api_key",
204
+ "api_base",
205
+ "original_response",
206
+ "additional_args",
207
  ]
208
+
209
  # Keys that might contain nested dictionaries to clean
210
  nested_keys = ["kwargs", "litellm_params", "model_info", "proxy_server_request"]
211
 
 
219
  # Remove sensitive/large keys
220
  for key in keys_to_pop:
221
  data_dict.pop(key, None)
222
+
223
  # Recursively clean nested dictionaries
224
  for key in nested_keys:
225
  if key in data_dict and isinstance(data_dict[key], dict):
226
  clean_recursively(data_dict[key])
227
+
228
  # Also iterate through all values to find any other nested dicts
229
  for key, value in list(data_dict.items()):
230
  if isinstance(value, dict):
 
243
  log_event_type = log_data.get("log_event_type")
244
  if log_event_type in ["pre_api_call", "post_api_call"]:
245
  return # Skip these verbose logs entirely
246
+
247
  # For successful calls or pre-call logs, a simple debug message is enough.
248
  if not log_data.get("exception"):
249
  sanitized_log = self._sanitize_litellm_log(log_data)
250
  # We log it at the DEBUG level to ensure it goes to the debug file
251
+ # and not the console, based on the main.py configuration.
252
  lib_logger.debug(f"LiteLLM Log: {sanitized_log}")
253
  return
254
 
255
  # For failures, extract key info to make debug logs more readable.
256
  model = log_data.get("model", "N/A")
257
  call_id = log_data.get("litellm_call_id", "N/A")
258
+ error_info = log_data.get("standard_logging_object", {}).get(
259
+ "error_information", {}
260
+ )
261
  error_class = error_info.get("error_class", "UnknownError")
262
+ error_message = error_info.get(
263
+ "error_message", str(log_data.get("exception", ""))
264
+ )
265
+ error_message = " ".join(error_message.split()) # Sanitize
266
 
267
  lib_logger.debug(
268
  f"LiteLLM Callback Handled Error: Model={model} | "
 
277
 
278
  async def close(self):
279
  """Close the HTTP client to prevent resource leaks."""
280
+ if hasattr(self, "http_client") and self.http_client:
281
  await self.http_client.aclose()
282
 
283
  def _convert_model_params(self, **kwargs) -> Dict[str, Any]:
 
290
  if not model:
291
  return kwargs
292
 
293
+ provider = model.split("/")[0]
294
  if provider == "chutes":
295
  kwargs["model"] = f"openai/{model.split('/', 1)[1]}"
296
  kwargs["api_base"] = "https://llm.chutes.ai/v1"
297
+
298
  return kwargs
299
 
300
  def get_oauth_credentials(self) -> Dict[str, List[str]]:
301
  return self.oauth_credentials
302
 
303
+ def _is_custom_openai_compatible_provider(self, provider_name: str) -> bool:
304
+ """Checks if a provider is a custom OpenAI-compatible provider."""
305
+ import os
306
+
307
+ # Check if the provider has an API_BASE environment variable
308
+ api_base_env = f"{provider_name.upper()}_API_BASE"
309
+ return os.getenv(api_base_env) is not None
310
+
311
  def _get_provider_instance(self, provider_name: str):
312
  """Lazily initializes and returns a provider instance."""
313
  if provider_name not in self._provider_instances:
314
  if provider_name in self._provider_plugins:
315
+ self._provider_instances[provider_name] = self._provider_plugins[
316
+ provider_name
317
+ ]()
318
+ elif self._is_custom_openai_compatible_provider(provider_name):
319
+ # Create a generic OpenAI-compatible provider for custom providers
320
+ try:
321
+ self._provider_instances[provider_name] = OpenAICompatibleProvider(
322
+ provider_name
323
+ )
324
+ except ValueError:
325
+ # If the provider doesn't have the required environment variables, treat it as a standard provider
326
+ return None
327
  else:
328
  return None
329
  return self._provider_instances[provider_name]
330
 
331
+ async def _safe_streaming_wrapper(
332
+ self, stream: Any, key: str, model: str, request: Optional[Any] = None
333
+ ) -> AsyncGenerator[Any, None]:
334
  """
335
  A hybrid wrapper for streaming that buffers fragmented JSON, handles client disconnections gracefully,
336
  and distinguishes between content and streamed errors.
 
343
  try:
344
  while True:
345
  if request and await request.is_disconnected():
346
+ lib_logger.info(
347
+ f"Client disconnected. Aborting stream for credential ...{key[-6:]}."
348
+ )
349
  # Do not yield [DONE] because the client is gone.
350
  # The 'finally' block will handle key release.
351
  break
 
355
  if json_buffer:
356
  # If we are about to discard a buffer, it means data was likely lost.
357
  # Log this as a warning to make it visible.
358
+ lib_logger.warning(
359
+ f"Discarding incomplete JSON buffer from previous chunk: {json_buffer}"
360
+ )
361
  json_buffer = ""
362
+
363
  yield f"data: {json.dumps(chunk.dict())}\n\n"
364
 
365
+ if hasattr(chunk, "usage") and chunk.usage:
366
+ last_usage = (
367
+ chunk.usage
368
+ ) # Overwrite with the latest (cumulative)
369
 
370
  except StopAsyncIteration:
371
  stream_completed = True
372
  if json_buffer:
373
+ lib_logger.info(
374
+ f"Stream ended with incomplete data in buffer: {json_buffer}"
375
+ )
376
  if last_usage:
377
  # Create a dummy ModelResponse for recording (only usage matters)
378
  dummy_response = litellm.ModelResponse(usage=last_usage)
379
+ await self.usage_manager.record_success(
380
+ key, model, dummy_response
381
+ )
382
  else:
383
  # If no usage seen (rare), record success without tokens/cost
384
  await self.usage_manager.record_success(key, model)
385
  break
386
 
387
+ except (
388
+ litellm.RateLimitError,
389
+ litellm.ServiceUnavailableError,
390
+ litellm.InternalServerError,
391
+ APIConnectionError,
392
+ ) as e:
393
  # This is a critical, typed error from litellm that signals a key failure.
394
  # We do not try to parse it here. We wrap it and raise it immediately
395
  # for the outer retry loop to handle.
396
+ lib_logger.warning(
397
+ f"Caught a critical API error mid-stream: {type(e).__name__}. Signaling for credential rotation."
398
+ )
399
  raise StreamedAPIError("Provider error received in stream", data=e)
400
 
401
  except Exception as e:
 
406
  match = re.search(r"b'(\{.*\})'", str(e), re.DOTALL)
407
  if match:
408
  # The extracted string is unicode-escaped (e.g., '\\n'). We must decode it.
409
+ raw_chunk = codecs.decode(match.group(1), "unicode_escape")
410
  else:
411
  # Fallback for other potential error formats that use "Received chunk:".
412
+ chunk_from_split = (
413
+ str(e).split("Received chunk:")[-1].strip()
414
+ )
415
+ if chunk_from_split != str(
416
+ e
417
+ ): # Ensure the split actually did something
418
  raw_chunk = chunk_from_split
419
+
420
  if not raw_chunk:
421
  # If we could not extract a valid chunk, we cannot proceed with reassembly.
422
  # This indicates a different, unexpected error type. Re-raise it.
 
425
  # Append the clean chunk to the buffer and try to parse.
426
  json_buffer += raw_chunk
427
  parsed_data = json.loads(json_buffer)
428
+
429
  # If parsing succeeds, we have the complete object.
430
+ lib_logger.info(
431
+ f"Successfully reassembled JSON from stream: {json_buffer}"
432
+ )
433
+
434
  # Wrap the complete error object and raise it. The outer function will decide how to handle it.
435
+ raise StreamedAPIError(
436
+ "Provider error received in stream", data=parsed_data
437
+ )
438
 
439
  except json.JSONDecodeError:
440
  # This is the expected outcome if the JSON in the buffer is not yet complete.
441
+ lib_logger.info(
442
+ f"Buffer still incomplete. Waiting for more chunks: {json_buffer}"
443
+ )
444
+ continue # Continue to the next loop to get the next chunk.
445
  except StreamedAPIError:
446
  # Re-raise to be caught by the outer retry handler.
447
  raise
448
  except Exception as buffer_exc:
449
  # If the error was not a JSONDecodeError, it's an unexpected internal error.
450
+ lib_logger.error(
451
+ f"Error during stream buffering logic: {buffer_exc}. Discarding buffer."
452
+ )
453
+ json_buffer = (
454
+ "" # Clear the corrupted buffer to prevent further issues.
455
+ )
456
  raise buffer_exc
457
+
458
  except StreamedAPIError:
459
  # This is caught by the acompletion retry logic.
460
  # We re-raise it to ensure it's not caught by the generic 'except Exception'.
 
463
  except Exception as e:
464
  # Catch any other unexpected errors during streaming.
465
  lib_logger.error(f"Caught unexpected exception of type: {type(e).__name__}")
466
+ lib_logger.error(
467
+ f"An unexpected error occurred during the stream for credential ...{key[-6:]}: {e}"
468
+ )
469
  # We still need to raise it so the client knows something went wrong.
470
  raise
471
 
 
473
  # This block now runs regardless of how the stream terminates (completion, client disconnect, etc.).
474
  # The primary goal is to ensure usage is always logged internally.
475
  await self.usage_manager.release_key(key, model)
476
+ lib_logger.info(
477
+ f"STREAM FINISHED and lock released for credential ...{key[-6:]}."
478
+ )
479
+
480
  # Only send [DONE] if the stream completed naturally and the client is still there.
481
  # This prevents sending [DONE] to a disconnected client or after an error.
482
+ if stream_completed and (
483
+ not request or not await request.is_disconnected()
484
+ ):
485
  yield "data: [DONE]\n\n"
486
 
487
+ async def _execute_with_retry(
488
+ self,
489
+ api_call: callable,
490
+ request: Optional[Any],
491
+ pre_request_callback: Optional[callable] = None,
492
+ **kwargs,
493
+ ) -> Any:
494
  """A generic retry mechanism for non-streaming API calls."""
495
  model = kwargs.get("model")
496
  if not model:
497
  raise ValueError("'model' is a required parameter.")
498
 
499
+ provider = model.split("/")[0]
500
  if provider not in self.all_credentials:
501
+ raise ValueError(
502
+ f"No API keys or OAuth credentials configured for provider: {provider}"
503
+ )
504
 
505
  # Establish a global deadline for the entire request lifecycle.
506
  deadline = time.time() + self.global_timeout
507
+
508
  # Create a mutable copy of the keys and shuffle it to ensure
509
  # that the key selection is randomized, which is crucial when
510
  # multiple keys have the same usage stats.
511
  credentials_for_provider = list(self.all_credentials[provider])
512
  random.shuffle(credentials_for_provider)
513
+
514
  tried_creds = set()
515
  last_exception = None
516
  kwargs = self._convert_model_params(**kwargs)
517
 
518
  # The main rotation loop. It continues as long as there are untried credentials and the global deadline has not been exceeded.
519
+ while (
520
+ len(tried_creds) < len(credentials_for_provider) and time.time() < deadline
521
+ ):
522
  current_cred = None
523
  key_acquired = False
524
  try:
525
  # Check for a provider-wide cooldown first.
526
  if await self.cooldown_manager.is_cooling_down(provider):
527
+ remaining_cooldown = (
528
+ await self.cooldown_manager.get_cooldown_remaining(provider)
529
+ )
530
  remaining_budget = deadline - time.time()
531
+
532
  # If the cooldown is longer than the remaining time budget, fail fast.
533
  if remaining_cooldown > remaining_budget:
534
+ lib_logger.warning(
535
+ f"Provider {provider} cooldown ({remaining_cooldown:.2f}s) exceeds remaining request budget ({remaining_budget:.2f}s). Failing early."
536
+ )
537
  break
538
 
539
+ lib_logger.warning(
540
+ f"Provider {provider} is in cooldown. Waiting for {remaining_cooldown:.2f} seconds."
541
+ )
542
  await asyncio.sleep(remaining_cooldown)
543
 
544
+ creds_to_try = [
545
+ c for c in credentials_for_provider if c not in tried_creds
546
+ ]
547
  if not creds_to_try:
548
  break
549
 
550
+ lib_logger.info(
551
+ f"Acquiring key for model {model}. Tried keys: {len(tried_creds)}/{len(credentials_for_provider)}"
552
+ )
553
  current_cred = await self.usage_manager.acquire_key(
554
+ available_keys=creds_to_try, model=model, deadline=deadline
 
 
555
  )
556
  key_acquired = True
557
  tried_creds.add(current_cred)
558
 
559
  litellm_kwargs = self.all_providers.get_provider_kwargs(**kwargs.copy())
560
+
561
  # [NEW] Merge provider-specific params
562
  if provider in self.litellm_provider_params:
563
  litellm_kwargs["litellm_params"] = {
564
  **self.litellm_provider_params[provider],
565
+ **litellm_kwargs.get("litellm_params", {}),
566
  }
567
+
568
  provider_plugin = self._get_provider_instance(provider)
569
  if provider_plugin and provider_plugin.has_custom_logic():
570
+ lib_logger.debug(
571
+ f"Provider '{provider}' has custom logic. Delegating call."
572
+ )
573
  litellm_kwargs["credential_identifier"] = current_cred
574
+ litellm_kwargs["enable_request_logging"] = (
575
+ self.enable_request_logging
576
+ )
577
 
578
  # Check body first for custom_reasoning_budget
579
  if "custom_reasoning_budget" in kwargs:
580
+ litellm_kwargs["custom_reasoning_budget"] = kwargs[
581
+ "custom_reasoning_budget"
582
+ ]
583
  else:
584
  custom_budget_header = None
585
+ if request and hasattr(request, "headers"):
586
+ custom_budget_header = request.headers.get(
587
+ "custom_reasoning_budget"
588
+ )
589
 
590
  if custom_budget_header is not None:
591
+ is_budget_enabled = custom_budget_header.lower() == "true"
592
+ litellm_kwargs["custom_reasoning_budget"] = (
593
+ is_budget_enabled
594
+ )
595
+
596
  # The plugin handles the entire call, including retries on 401, etc.
597
  # The main retry loop here is for key rotation on other errors.
598
+ response = await provider_plugin.acompletion(
599
+ self.http_client, **litellm_kwargs
600
+ )
601
+
602
  # For non-streaming, success is immediate, and this function only handles non-streaming.
603
+ await self.usage_manager.record_success(
604
+ current_cred, model, response
605
+ )
606
  await self.usage_manager.release_key(current_cred, model)
607
  key_acquired = False
608
  return response
609
 
610
+ else: # This is the standard API Key / litellm-handled provider logic
611
  is_oauth = provider in self.oauth_providers
612
+ if is_oauth: # Standard OAuth provider (not custom)
613
  # ... (logic to set headers) ...
614
  pass
615
+ else: # API Key
616
  litellm_kwargs["api_key"] = current_cred
617
+
618
  provider_instance = self._get_provider_instance(provider)
619
  if provider_instance:
620
  if "safety_settings" in litellm_kwargs:
621
+ converted_settings = (
622
+ provider_instance.convert_safety_settings(
623
+ litellm_kwargs["safety_settings"]
624
+ )
625
+ )
626
  if converted_settings is not None:
627
  litellm_kwargs["safety_settings"] = converted_settings
628
  else:
629
  del litellm_kwargs["safety_settings"]
630
+
631
  if provider == "gemini" and provider_instance:
632
+ provider_instance.handle_thinking_parameter(
633
+ litellm_kwargs, model
634
+ )
635
  if provider == "nvidia_nim" and provider_instance:
636
+ provider_instance.handle_thinking_parameter(
637
+ litellm_kwargs, model
638
+ )
639
 
640
  if "gemma-3" in model and "messages" in litellm_kwargs:
641
+ litellm_kwargs["messages"] = [
642
+ {"role": "user", "content": m["content"]}
643
+ if m.get("role") == "system"
644
+ else m
645
+ for m in litellm_kwargs["messages"]
646
+ ]
647
+
648
  litellm_kwargs = sanitize_request_payload(litellm_kwargs, model)
649
 
650
  for attempt in range(self.max_retries):
651
  try:
652
+ lib_logger.info(
653
+ f"Attempting call with credential ...{current_cred[-6:]} (Attempt {attempt + 1}/{self.max_retries})"
654
+ )
655
+
656
  if pre_request_callback:
657
  try:
658
  await pre_request_callback(request, litellm_kwargs)
659
  except Exception as e:
660
  if self.abort_on_callback_error:
661
+ raise PreRequestCallbackError(
662
+ f"Pre-request callback failed: {e}"
663
+ ) from e
664
  else:
665
+ lib_logger.warning(
666
+ f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}"
667
+ )
668
+
669
  response = await api_call(
670
  **litellm_kwargs,
671
+ logger_fn=self._litellm_logger_callback,
672
+ )
673
+
674
+ await self.usage_manager.record_success(
675
+ current_cred, model, response
676
  )
 
 
677
  await self.usage_manager.release_key(current_cred, model)
678
  key_acquired = False
679
  return response
680
 
681
  except litellm.RateLimitError as e:
682
  last_exception = e
683
+ log_failure(
684
+ api_key=current_cred,
685
+ model=model,
686
+ attempt=attempt + 1,
687
+ error=e,
688
+ request_headers=dict(request.headers)
689
+ if request
690
+ else {},
691
+ )
692
  classified_error = classify_error(e)
693
+
694
  # Extract a clean error message for the user-facing log
695
+ error_message = str(e).split("\n")[0]
696
+ lib_logger.info(
697
+ f"Key ...{current_cred[-6:]} hit rate limit for model {model}. Reason: '{error_message}'. Rotating key."
698
+ )
699
 
700
  if classified_error.status_code == 429:
701
  cooldown_duration = classified_error.retry_after or 60
702
+ await self.cooldown_manager.start_cooldown(
703
+ provider, cooldown_duration
704
+ )
705
+ lib_logger.warning(
706
+ f"IP-based rate limit detected for {provider}. Starting a {cooldown_duration}-second global cooldown."
707
+ )
708
+
709
+ await self.usage_manager.record_failure(
710
+ current_cred, model, classified_error
711
+ )
712
+ lib_logger.warning(
713
+ f"Key ...{current_cred[-6:]} encountered a rate limit. Trying next key."
714
+ )
715
+ break # Move to the next key
716
 
717
+ except (
718
+ APIConnectionError,
719
+ litellm.InternalServerError,
720
+ litellm.ServiceUnavailableError,
721
+ ) as e:
722
  last_exception = e
723
+ log_failure(
724
+ api_key=current_cred,
725
+ model=model,
726
+ attempt=attempt + 1,
727
+ error=e,
728
+ request_headers=dict(request.headers)
729
+ if request
730
+ else {},
731
+ )
732
  classified_error = classify_error(e)
733
+ await self.usage_manager.record_failure(
734
+ current_cred, model, classified_error
735
+ )
736
 
737
  if attempt >= self.max_retries - 1:
738
+ error_message = str(e).split("\n")[0]
739
+ lib_logger.warning(
740
+ f"Key ...{current_cred[-6:]} failed after max retries for model {model} due to a server error. Reason: '{error_message}'. Rotating key."
741
+ )
742
+ break # Move to the next key
743
+
744
  # For temporary errors, wait before retrying with the same key.
745
+ wait_time = classified_error.retry_after or (
746
+ 1 * (2**attempt)
747
+ ) + random.uniform(0, 1)
748
  remaining_budget = deadline - time.time()
749
+
750
  # If the required wait time exceeds the budget, don't wait; rotate to the next key immediately.
751
  if wait_time > remaining_budget:
752
+ lib_logger.warning(
753
+ f"Required retry wait time ({wait_time:.2f}s) exceeds remaining budget ({remaining_budget:.2f}s). Rotating key early."
754
+ )
755
  break
756
 
757
+ error_message = str(e).split("\n")[0]
758
+ lib_logger.warning(
759
+ f"Key ...{current_cred[-6:]} encountered a server error for model {model}. Reason: '{error_message}'. Retrying in {wait_time:.2f}s."
760
+ )
761
  await asyncio.sleep(wait_time)
762
+ continue # Retry with the same key
763
 
764
  except Exception as e:
765
  last_exception = e
766
+ log_failure(
767
+ api_key=current_cred,
768
+ model=model,
769
+ attempt=attempt + 1,
770
+ error=e,
771
+ request_headers=dict(request.headers)
772
+ if request
773
+ else {},
774
+ )
775
+
776
  if request and await request.is_disconnected():
777
+ lib_logger.warning(
778
+ f"Client disconnected. Aborting retries for credential ...{current_cred[-6:]}."
779
+ )
780
  raise last_exception
781
 
782
  classified_error = classify_error(e)
783
+ error_message = str(e).split("\n")[0]
784
+ lib_logger.warning(
785
+ f"Key ...{current_cred[-6:]} failed with {classified_error.error_type} (Status: {classified_error.status_code}). Error: {error_message}. Rotating key."
786
+ )
787
  if classified_error.status_code == 429:
788
  cooldown_duration = classified_error.retry_after or 60
789
+ await self.cooldown_manager.start_cooldown(
790
+ provider, cooldown_duration
791
+ )
792
+ lib_logger.warning(
793
+ f"IP-based rate limit detected for {provider} from generic exception. Starting a {cooldown_duration}-second global cooldown."
794
+ )
795
+
796
+ if classified_error.error_type in [
797
+ "invalid_request",
798
+ "context_window_exceeded",
799
+ "authentication",
800
+ ]:
801
  # For these errors, we should not retry with other keys.
802
  raise last_exception
803
 
804
+ await self.usage_manager.record_failure(
805
+ current_cred, model, classified_error
806
+ )
807
+ break # Try next key for other errors
808
  finally:
809
  if key_acquired and current_cred:
810
  await self.usage_manager.release_key(current_cred, model)
 
812
  if last_exception:
813
  # Log the final error but do not raise it, as per the new requirement.
814
  # The client should not see intermittent failures.
815
+ lib_logger.error(
816
+ f"Request failed after trying all keys or exceeding global timeout. Last error: {last_exception}"
817
+ )
818
+
819
  # Return None to indicate failure without propagating a disruptive exception.
820
  return None
821
 
822
+ async def _streaming_acompletion_with_retry(
823
+ self,
824
+ request: Optional[Any],
825
+ pre_request_callback: Optional[callable] = None,
826
+ **kwargs,
827
+ ) -> AsyncGenerator[str, None]:
828
  """A dedicated generator for retrying streaming completions with full request preparation and per-key retries."""
829
  model = kwargs.get("model")
830
+ provider = model.split("/")[0]
831
+
832
  # Create a mutable copy of the keys and shuffle it.
833
  credentials_for_provider = list(self.all_credentials[provider])
834
  random.shuffle(credentials_for_provider)
835
+
836
  deadline = time.time() + self.global_timeout
837
  tried_creds = set()
838
  last_exception = None
839
  kwargs = self._convert_model_params(**kwargs)
840
+
841
  consecutive_quota_failures = 0
842
 
843
  try:
844
+ while (
845
+ len(tried_creds) < len(credentials_for_provider)
846
+ and time.time() < deadline
847
+ ):
848
  current_cred = None
849
  key_acquired = False
850
  try:
851
  if await self.cooldown_manager.is_cooling_down(provider):
852
+ remaining_cooldown = (
853
+ await self.cooldown_manager.get_cooldown_remaining(provider)
854
+ )
855
  remaining_budget = deadline - time.time()
856
  if remaining_cooldown > remaining_budget:
857
+ lib_logger.warning(
858
+ f"Provider {provider} cooldown ({remaining_cooldown:.2f}s) exceeds remaining request budget ({remaining_budget:.2f}s). Failing early."
859
+ )
860
  break
861
+ lib_logger.warning(
862
+ f"Provider {provider} is in a global cooldown. All requests to this provider will be paused for {remaining_cooldown:.2f} seconds."
863
+ )
864
  await asyncio.sleep(remaining_cooldown)
865
 
866
+ creds_to_try = [
867
+ c for c in credentials_for_provider if c not in tried_creds
868
+ ]
869
  if not creds_to_try:
870
+ lib_logger.warning(
871
+ f"All credentials for provider {provider} have been tried. No more credentials to rotate to."
872
+ )
873
  break
874
 
875
+ lib_logger.info(
876
+ f"Acquiring credential for model {model}. Tried credentials: {len(tried_creds)}/{len(credentials_for_provider)}"
877
+ )
878
  current_cred = await self.usage_manager.acquire_key(
879
+ available_keys=creds_to_try, model=model, deadline=deadline
 
 
880
  )
881
  key_acquired = True
882
  tried_creds.add(current_cred)
883
 
884
+ litellm_kwargs = self.all_providers.get_provider_kwargs(
885
+ **kwargs.copy()
886
+ )
887
  if "reasoning_effort" in kwargs:
888
  litellm_kwargs["reasoning_effort"] = kwargs["reasoning_effort"]
889
  # Check body first for custom_reasoning_budget
890
  if "custom_reasoning_budget" in kwargs:
891
+ litellm_kwargs["custom_reasoning_budget"] = kwargs[
892
+ "custom_reasoning_budget"
893
+ ]
894
  else:
895
  custom_budget_header = None
896
+ if request and hasattr(request, "headers"):
897
+ custom_budget_header = request.headers.get(
898
+ "custom_reasoning_budget"
899
+ )
900
 
901
  if custom_budget_header is not None:
902
+ is_budget_enabled = custom_budget_header.lower() == "true"
903
+ litellm_kwargs["custom_reasoning_budget"] = (
904
+ is_budget_enabled
905
+ )
906
+
907
  # [NEW] Merge provider-specific params
908
  if provider in self.litellm_provider_params:
909
  litellm_kwargs["litellm_params"] = {
910
  **self.litellm_provider_params[provider],
911
+ **litellm_kwargs.get("litellm_params", {}),
912
  }
913
 
914
  provider_plugin = self._get_provider_instance(provider)
915
  if provider_plugin and provider_plugin.has_custom_logic():
916
+ lib_logger.debug(
917
+ f"Provider '{provider}' has custom logic. Delegating call."
918
+ )
919
  litellm_kwargs["credential_identifier"] = current_cred
920
+ litellm_kwargs["enable_request_logging"] = (
921
+ self.enable_request_logging
922
+ )
923
+
924
  for attempt in range(self.max_retries):
925
  try:
926
+ lib_logger.info(
927
+ f"Attempting stream with credential ...{current_cred[-6:]} (Attempt {attempt + 1}/{self.max_retries})"
928
+ )
929
+
930
  if pre_request_callback:
931
  try:
932
+ await pre_request_callback(
933
+ request, litellm_kwargs
934
+ )
935
  except Exception as e:
936
  if self.abort_on_callback_error:
937
+ raise PreRequestCallbackError(
938
+ f"Pre-request callback failed: {e}"
939
+ ) from e
940
  else:
941
+ lib_logger.warning(
942
+ f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}"
943
+ )
944
+
945
+ response = await provider_plugin.acompletion(
946
+ self.http_client, **litellm_kwargs
947
+ )
948
+
949
+ lib_logger.info(
950
+ f"Stream connection established for credential ...{current_cred[-6:]}. Processing response."
951
+ )
952
 
953
  key_acquired = False
954
+ stream_generator = self._safe_streaming_wrapper(
955
+ response, current_cred, model, request
956
+ )
957
+
958
  async for chunk in stream_generator:
959
  yield chunk
960
  return
961
 
962
+ except (
963
+ StreamedAPIError,
964
+ litellm.RateLimitError,
965
+ httpx.HTTPStatusError,
966
+ ) as e:
967
+ if (
968
+ isinstance(e, httpx.HTTPStatusError)
969
+ and e.response.status_code != 429
970
+ ):
971
  raise e
972
 
973
  last_exception = e
974
  # If the exception is our custom wrapper, unwrap the original error
975
+ original_exc = getattr(e, "data", e)
976
  classified_error = classify_error(original_exc)
977
+ await self.usage_manager.record_failure(
978
+ current_cred, model, classified_error
979
+ )
980
+ lib_logger.warning(
981
+ f"Credential ...{current_cred[-6:]} encountered a recoverable error ({classified_error.error_type}) during custom provider stream. Rotating key."
982
+ )
983
  break
984
 
985
+ except (
986
+ APIConnectionError,
987
+ litellm.InternalServerError,
988
+ litellm.ServiceUnavailableError,
989
+ ) as e:
990
  last_exception = e
991
+ log_failure(
992
+ api_key=current_cred,
993
+ model=model,
994
+ attempt=attempt + 1,
995
+ error=e,
996
+ request_headers=dict(request.headers)
997
+ if request
998
+ else {},
999
+ )
1000
  classified_error = classify_error(e)
1001
+ await self.usage_manager.record_failure(
1002
+ current_cred, model, classified_error
1003
+ )
1004
 
1005
  if attempt >= self.max_retries - 1:
1006
+ lib_logger.warning(
1007
+ f"Credential ...{current_cred[-6:]} failed after max retries for model {model} due to a server error. Rotating key."
1008
+ )
1009
  break
1010
+
1011
+ wait_time = classified_error.retry_after or (
1012
+ 1 * (2**attempt)
1013
+ ) + random.uniform(0, 1)
1014
  remaining_budget = deadline - time.time()
1015
  if wait_time > remaining_budget:
1016
+ lib_logger.warning(
1017
+ f"Required retry wait time ({wait_time:.2f}s) exceeds remaining budget ({remaining_budget:.2f}s). Rotating key early."
1018
+ )
1019
  break
1020
+
1021
+ error_message = str(e).split("\n")[0]
1022
+ lib_logger.warning(
1023
+ f"Credential ...{current_cred[-6:]} encountered a server error for model {model}. Reason: '{error_message}'. Retrying in {wait_time:.2f}s."
1024
+ )
1025
  await asyncio.sleep(wait_time)
1026
  continue
1027
 
1028
  except Exception as e:
1029
  last_exception = e
1030
+ log_failure(
1031
+ api_key=current_cred,
1032
+ model=model,
1033
+ attempt=attempt + 1,
1034
+ error=e,
1035
+ request_headers=dict(request.headers)
1036
+ if request
1037
+ else {},
1038
+ )
1039
  classified_error = classify_error(e)
1040
+ lib_logger.warning(
1041
+ f"Credential ...{current_cred[-6:]} failed with {classified_error.error_type} (Status: {classified_error.status_code}). Error: {str(e)}. Rotating key."
1042
+ )
1043
+ if classified_error.error_type in [
1044
+ "invalid_request",
1045
+ "context_window_exceeded",
1046
+ "authentication",
1047
+ ]:
1048
  raise last_exception
1049
+ await self.usage_manager.record_failure(
1050
+ current_cred, model, classified_error
1051
+ )
1052
  break
1053
+
1054
  # If the inner loop breaks, it means the key failed and we need to rotate.
1055
  # Continue to the next iteration of the outer while loop to pick a new key.
1056
  continue
1057
 
1058
+ else: # This is the standard API Key / litellm-handled provider logic
1059
  is_oauth = provider in self.oauth_providers
1060
+ if is_oauth: # Standard OAuth provider (not custom)
1061
  # ... (logic to set headers) ...
1062
  pass
1063
+ else: # API Key
1064
  litellm_kwargs["api_key"] = current_cred
1065
 
1066
  provider_instance = self._get_provider_instance(provider)
1067
  if provider_instance:
1068
  if "safety_settings" in litellm_kwargs:
1069
+ converted_settings = (
1070
+ provider_instance.convert_safety_settings(
1071
+ litellm_kwargs["safety_settings"]
1072
+ )
1073
+ )
1074
  if converted_settings is not None:
1075
  litellm_kwargs["safety_settings"] = converted_settings
1076
  else:
1077
  del litellm_kwargs["safety_settings"]
1078
+
1079
  if provider == "gemini" and provider_instance:
1080
+ provider_instance.handle_thinking_parameter(
1081
+ litellm_kwargs, model
1082
+ )
1083
  if provider == "nvidia_nim" and provider_instance:
1084
+ provider_instance.handle_thinking_parameter(
1085
+ litellm_kwargs, model
1086
+ )
1087
 
1088
  if "gemma-3" in model and "messages" in litellm_kwargs:
1089
+ litellm_kwargs["messages"] = [
1090
+ {"role": "user", "content": m["content"]}
1091
+ if m.get("role") == "system"
1092
+ else m
1093
+ for m in litellm_kwargs["messages"]
1094
+ ]
1095
+
1096
  litellm_kwargs = sanitize_request_payload(litellm_kwargs, model)
1097
 
1098
  # If the provider is 'qwen_code', set the custom provider to 'qwen'
1099
  # and strip the prefix from the model name for LiteLLM.
1100
  if provider == "qwen_code":
1101
  litellm_kwargs["custom_llm_provider"] = "qwen"
1102
+ litellm_kwargs["model"] = model.split("/", 1)[1]
1103
 
1104
  for attempt in range(self.max_retries):
1105
  try:
1106
+ lib_logger.info(
1107
+ f"Attempting stream with credential ...{current_cred[-6:]} (Attempt {attempt + 1}/{self.max_retries})"
1108
+ )
1109
+
1110
  if pre_request_callback:
1111
  try:
1112
  await pre_request_callback(request, litellm_kwargs)
1113
  except Exception as e:
1114
  if self.abort_on_callback_error:
1115
+ raise PreRequestCallbackError(
1116
+ f"Pre-request callback failed: {e}"
1117
+ ) from e
1118
  else:
1119
+ lib_logger.warning(
1120
+ f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}"
1121
+ )
1122
+
1123
+ # lib_logger.info(f"DEBUG: litellm.acompletion kwargs: {litellm_kwargs}")
1124
  response = await litellm.acompletion(
1125
  **litellm_kwargs,
1126
+ logger_fn=self._litellm_logger_callback,
1127
+ )
1128
+
1129
+ lib_logger.info(
1130
+ f"Stream connection established for credential ...{current_cred[-6:]}. Processing response."
1131
  )
 
 
1132
 
1133
  key_acquired = False
1134
+ stream_generator = self._safe_streaming_wrapper(
1135
+ response, current_cred, model, request
1136
+ )
1137
+
1138
  async for chunk in stream_generator:
1139
  yield chunk
1140
  return
1141
 
1142
  except (StreamedAPIError, litellm.RateLimitError) as e:
1143
  last_exception = e
1144
+
1145
  # This is the final, robust handler for streamed errors.
1146
  error_payload = {}
1147
  cleaned_str = None
1148
  # The actual exception might be wrapped in our StreamedAPIError.
1149
+ original_exc = getattr(e, "data", e)
1150
  classified_error = classify_error(original_exc)
1151
 
1152
  try:
1153
  # The full error JSON is in the string representation of the exception.
1154
+ json_str_match = re.search(
1155
+ r"(\{.*\})", str(original_exc), re.DOTALL
1156
+ )
1157
  if json_str_match:
1158
  # The string may contain byte-escaped characters (e.g., \\n).
1159
+ cleaned_str = codecs.decode(
1160
+ json_str_match.group(1), "unicode_escape"
1161
+ )
1162
  error_payload = json.loads(cleaned_str)
1163
  except (json.JSONDecodeError, TypeError):
1164
+ lib_logger.warning(
1165
+ "Could not parse JSON details from streamed error exception."
1166
+ )
1167
  error_payload = {}
1168
+
1169
  # Now, log the failure with the extracted raw response.
1170
  log_failure(
1171
  api_key=current_cred,
1172
  model=model,
1173
  attempt=attempt + 1,
1174
  error=e,
1175
+ request_headers=dict(request.headers)
1176
+ if request
1177
+ else {},
1178
+ raw_response_text=cleaned_str,
1179
  )
1180
 
1181
  error_details = error_payload.get("error", {})
1182
  error_status = error_details.get("status", "")
1183
  # Fallback to the full string if parsing fails.
1184
+ error_message_text = error_details.get(
1185
+ "message", str(original_exc)
1186
+ )
1187
 
1188
+ if (
1189
+ "quota" in error_message_text.lower()
1190
+ or "resource_exhausted" in error_status.lower()
1191
+ ):
1192
  consecutive_quota_failures += 1
1193
+ lib_logger.warning(
1194
+ f"Credential ...{current_cred[-6:]} hit a quota limit. This is consecutive failure #{consecutive_quota_failures} for this request."
1195
+ )
1196
 
1197
  quota_value = "N/A"
1198
  quota_id = "N/A"
1199
+ if "details" in error_details and isinstance(
1200
+ error_details.get("details"), list
1201
+ ):
1202
  for detail in error_details["details"]:
1203
  if isinstance(detail.get("violations"), list):
1204
  for violation in detail["violations"]:
1205
  if "quotaValue" in violation:
1206
+ quota_value = violation[
1207
+ "quotaValue"
1208
+ ]
1209
  if "quotaId" in violation:
1210
  quota_id = violation["quotaId"]
1211
+ if (
1212
+ quota_value != "N/A"
1213
+ and quota_id != "N/A"
1214
+ ):
1215
  break
1216
+
1217
+ await self.usage_manager.record_failure(
1218
+ current_cred, model, classified_error
1219
+ )
1220
 
1221
  if consecutive_quota_failures >= 3:
1222
  console_log_message = (
 
1229
  f"Last Error Message: '{error_message_text}'. Limit: {quota_value} (Quota ID: {quota_id})."
1230
  )
1231
  lib_logger.error(console_log_message)
1232
+
1233
  yield f"data: {json.dumps({'error': {'message': client_error_message, 'type': 'proxy_fatal_quota_error'}})}\n\n"
1234
  yield "data: [DONE]\n\n"
1235
  return
1236
+
1237
  else:
1238
  # [MODIFIED] Do not yield to the client. Just log and break to rotate the key.
1239
+ lib_logger.warning(
1240
+ f"Quota error on credential ...{current_cred[-6:]} (failure {consecutive_quota_failures}/3). Rotating key silently."
1241
+ )
1242
  break
1243
+
1244
  else:
1245
  consecutive_quota_failures = 0
1246
  # [MODIFIED] Do not yield to the client. Just log and break to rotate the key.
1247
+ lib_logger.warning(
1248
+ f"Credential ...{current_cred[-6:]} encountered a recoverable error ({classified_error.error_type}) during stream. Rotating key silently."
1249
+ )
1250
+
1251
+ if (
1252
+ classified_error.error_type == "rate_limit"
1253
+ and classified_error.status_code == 429
1254
+ ):
1255
+ cooldown_duration = (
1256
+ classified_error.retry_after or 60
1257
+ )
1258
+ await self.cooldown_manager.start_cooldown(
1259
+ provider, cooldown_duration
1260
+ )
1261
+ lib_logger.warning(
1262
+ f"IP-based rate limit detected for {provider}. Starting a {cooldown_duration}-second global cooldown."
1263
+ )
1264
+
1265
+ await self.usage_manager.record_failure(
1266
+ current_cred, model, classified_error
1267
+ )
1268
  break
1269
 
1270
+ except (
1271
+ APIConnectionError,
1272
+ litellm.InternalServerError,
1273
+ litellm.ServiceUnavailableError,
1274
+ ) as e:
1275
  consecutive_quota_failures = 0
1276
  last_exception = e
1277
+ log_failure(
1278
+ api_key=current_cred,
1279
+ model=model,
1280
+ attempt=attempt + 1,
1281
+ error=e,
1282
+ request_headers=dict(request.headers)
1283
+ if request
1284
+ else {},
1285
+ )
1286
  classified_error = classify_error(e)
1287
+ await self.usage_manager.record_failure(
1288
+ current_cred, model, classified_error
1289
+ )
1290
 
1291
  if attempt >= self.max_retries - 1:
1292
+ lib_logger.warning(
1293
+ f"Credential ...{current_cred[-6:]} failed after max retries for model {model} due to a server error. Rotating key silently."
1294
+ )
1295
  # [MODIFIED] Do not yield to the client here.
1296
  break
1297
+
1298
+ wait_time = classified_error.retry_after or (
1299
+ 1 * (2**attempt)
1300
+ ) + random.uniform(0, 1)
1301
  remaining_budget = deadline - time.time()
1302
  if wait_time > remaining_budget:
1303
+ lib_logger.warning(
1304
+ f"Required retry wait time ({wait_time:.2f}s) exceeds remaining budget ({remaining_budget:.2f}s). Rotating key early."
1305
+ )
1306
  break
1307
+
1308
+ error_message = str(e).split("\n")[0]
1309
+ lib_logger.warning(
1310
+ f"Credential ...{current_cred[-6:]} encountered a server error for model {model}. Reason: '{error_message}'. Retrying in {wait_time:.2f}s."
1311
+ )
1312
  await asyncio.sleep(wait_time)
1313
  continue
1314
 
1315
  except Exception as e:
1316
  consecutive_quota_failures = 0
1317
  last_exception = e
1318
+ log_failure(
1319
+ api_key=current_cred,
1320
+ model=model,
1321
+ attempt=attempt + 1,
1322
+ error=e,
1323
+ request_headers=dict(request.headers)
1324
+ if request
1325
+ else {},
1326
+ )
1327
  classified_error = classify_error(e)
1328
 
1329
+ lib_logger.warning(
1330
+ f"Credential ...{current_cred[-6:]} failed with {classified_error.error_type} (Status: {classified_error.status_code}). Error: {str(e)}. Rotating key."
1331
+ )
1332
 
1333
  if classified_error.status_code == 429:
1334
  cooldown_duration = classified_error.retry_after or 60
1335
+ await self.cooldown_manager.start_cooldown(
1336
+ provider, cooldown_duration
1337
+ )
1338
+ lib_logger.warning(
1339
+ f"IP-based rate limit detected for {provider} from generic stream exception. Starting a {cooldown_duration}-second global cooldown."
1340
+ )
1341
+
1342
+ if classified_error.error_type in [
1343
+ "invalid_request",
1344
+ "context_window_exceeded",
1345
+ "authentication",
1346
+ ]:
1347
  raise last_exception
1348
+
1349
  # [MODIFIED] Do not yield to the client here.
1350
+ await self.usage_manager.record_failure(
1351
+ current_cred, model, classified_error
1352
+ )
1353
  break
1354
 
1355
  finally:
1356
  if key_acquired and current_cred:
1357
  await self.usage_manager.release_key(current_cred, model)
1358
+
1359
  final_error_message = "Failed to complete the streaming request: No available API keys after rotation or global timeout exceeded."
1360
  if last_exception:
1361
  final_error_message = f"Failed to complete the streaming request. Last error: {str(last_exception)}"
1362
+ lib_logger.error(
1363
+ f"Streaming request failed after trying all keys. Last error: {last_exception}"
1364
+ )
1365
  else:
1366
  lib_logger.error(final_error_message)
1367
 
1368
+ error_data = {
1369
+ "error": {"message": final_error_message, "type": "proxy_error"}
1370
+ }
1371
  yield f"data: {json.dumps(error_data)}\n\n"
1372
  yield "data: [DONE]\n\n"
1373
 
1374
  except NoAvailableKeysError as e:
1375
+ lib_logger.error(
1376
+ f"A streaming request failed because no keys were available within the time budget: {e}"
1377
+ )
1378
  error_data = {"error": {"message": str(e), "type": "proxy_busy"}}
1379
  yield f"data: {json.dumps(error_data)}\n\n"
1380
  yield "data: [DONE]\n\n"
1381
  except Exception as e:
1382
  # This will now only catch fatal errors that should be raised, like invalid requests.
1383
+ lib_logger.error(
1384
+ f"An unhandled exception occurred in streaming retry logic: {e}",
1385
+ exc_info=True,
1386
+ )
1387
+ error_data = {
1388
+ "error": {
1389
+ "message": f"An unexpected error occurred: {str(e)}",
1390
+ "type": "proxy_internal_error",
1391
+ }
1392
+ }
1393
  yield f"data: {json.dumps(error_data)}\n\n"
1394
  yield "data: [DONE]\n\n"
1395
 
1396
+ def acompletion(
1397
+ self,
1398
+ request: Optional[Any] = None,
1399
+ pre_request_callback: Optional[callable] = None,
1400
+ **kwargs,
1401
+ ) -> Union[Any, AsyncGenerator[str, None]]:
1402
  """
1403
  Dispatcher for completion requests.
1404
 
 
1417
  kwargs["stream_options"] = {}
1418
  if "include_usage" not in kwargs["stream_options"]:
1419
  kwargs["stream_options"]["include_usage"] = True
1420
+ return self._streaming_acompletion_with_retry(
1421
+ request=request, pre_request_callback=pre_request_callback, **kwargs
1422
+ )
1423
  else:
1424
+ return self._execute_with_retry(
1425
+ litellm.acompletion,
1426
+ request=request,
1427
+ pre_request_callback=pre_request_callback,
1428
+ **kwargs,
1429
+ )
1430
+
1431
+ def aembedding(
1432
+ self,
1433
+ request: Optional[Any] = None,
1434
+ pre_request_callback: Optional[callable] = None,
1435
+ **kwargs,
1436
+ ) -> Any:
1437
  """
1438
  Executes an embedding request with retry logic.
1439
 
 
1447
  Returns:
1448
  The embedding response object, or None if all retries fail.
1449
  """
1450
+ return self._execute_with_retry(
1451
+ litellm.aembedding,
1452
+ request=request,
1453
+ pre_request_callback=pre_request_callback,
1454
+ **kwargs,
1455
+ )
1456
 
1457
  def token_count(self, **kwargs) -> int:
1458
  """Calculates the number of tokens for a given text or list of messages."""
 
1495
  for credential in shuffled_credentials:
1496
  try:
1497
  # Display last 6 chars for API keys, or the filename for OAuth paths
1498
+ cred_display = (
1499
+ credential[-6:]
1500
+ if not os.path.isfile(credential)
1501
+ else os.path.basename(credential)
1502
+ )
1503
+ lib_logger.debug(
1504
+ f"Attempting to get models for {provider} with credential ...{cred_display}"
1505
+ )
1506
+ models = await provider_instance.get_models(
1507
+ credential, self.http_client
1508
+ )
1509
+ lib_logger.info(
1510
+ f"Got {len(models)} models for provider: {provider}"
1511
+ )
1512
 
1513
  # Whitelist and blacklist logic
1514
  final_models = []
 
1524
  final_models.append(m)
1525
 
1526
  if len(final_models) != len(models):
1527
+ lib_logger.info(
1528
+ f"Filtered out {len(models) - len(final_models)} models for provider {provider}."
1529
+ )
1530
 
1531
  self._model_list_cache[provider] = final_models
1532
  return final_models
1533
  except Exception as e:
1534
  classified_error = classify_error(e)
1535
+ cred_display = (
1536
+ credential[-6:]
1537
+ if not os.path.isfile(credential)
1538
+ else os.path.basename(credential)
1539
+ )
1540
+ lib_logger.debug(
1541
+ f"Failed to get models for provider {provider} with credential ...{cred_display}: {classified_error.error_type}. Trying next credential."
1542
+ )
1543
+ continue # Try the next credential
1544
 
1545
+ lib_logger.error(
1546
+ f"Failed to get models for provider {provider} after trying all credentials."
1547
+ )
1548
  return []
1549
 
1550
+ async def get_all_available_models(
1551
+ self, grouped: bool = True
1552
+ ) -> Union[Dict[str, List[str]], List[str]]:
1553
  """Returns a list of all available models, either grouped by provider or as a flat list."""
1554
  lib_logger.info("Getting all available models...")
1555
+
1556
  all_providers = list(self.all_credentials.keys())
1557
  tasks = [self.get_available_models(provider) for provider in all_providers]
1558
  results = await asyncio.gather(*tasks, return_exceptions=True)
 
1560
  all_provider_models = {}
1561
  for provider, result in zip(all_providers, results):
1562
  if isinstance(result, Exception):
1563
+ lib_logger.error(
1564
+ f"Failed to get models for provider {provider}: {result}"
1565
+ )
1566
  all_provider_models[provider] = []
1567
  else:
1568
  all_provider_models[provider] = result
1569
+
1570
  lib_logger.info("Finished getting all available models.")
1571
  if grouped:
1572
  return all_provider_models
src/rotator_library/error_handler.py CHANGED
@@ -3,19 +3,42 @@ import json
3
  from typing import Optional, Dict, Any
4
  import httpx
5
 
6
- from litellm.exceptions import APIConnectionError, RateLimitError, ServiceUnavailableError, AuthenticationError, InvalidRequestError, BadRequestError, OpenAIError, InternalServerError, Timeout, ContextWindowExceededError
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  class NoAvailableKeysError(Exception):
9
  """Raised when no API keys are available for a request after waiting."""
 
10
  pass
11
 
 
12
  class PreRequestCallbackError(Exception):
13
  """Raised when a pre-request callback fails."""
 
14
  pass
15
 
 
16
  class ClassifiedError:
17
  """A structured representation of a classified error."""
18
- def __init__(self, error_type: str, original_exception: Exception, status_code: Optional[int] = None, retry_after: Optional[int] = None):
 
 
 
 
 
 
 
19
  self.error_type = error_type
20
  self.original_exception = original_exception
21
  self.status_code = status_code
@@ -24,6 +47,7 @@ class ClassifiedError:
24
  def __str__(self):
25
  return f"ClassifiedError(type={self.error_type}, status={self.status_code}, retry_after={self.retry_after}, original_exc={self.original_exception})"
26
 
 
27
  def get_retry_after(error: Exception) -> Optional[int]:
28
  """
29
  Extracts the 'retry-after' duration in seconds from an exception message.
@@ -34,31 +58,31 @@ def get_retry_after(error: Exception) -> Optional[int]:
34
  # 1. Try to parse JSON from the error string to find 'retryDelay'
35
  try:
36
  # It's common for the actual JSON to be embedded in the string representation
37
- json_match = re.search(r'(\{.*\})', error_str)
38
  if json_match:
39
  error_json = json.loads(json_match.group(1))
40
- retry_info = error_json.get('error', {}).get('details', [{}])[0]
41
- if retry_info.get('@type') == 'type.googleapis.com/google.rpc.RetryInfo':
42
- delay_str = retry_info.get('retryDelay', {}).get('seconds')
43
  if delay_str:
44
  return int(delay_str)
45
  # Fallback for the other format
46
- delay_str = retry_info.get('retryDelay')
47
- if isinstance(delay_str, str) and delay_str.endswith('s'):
48
  return int(delay_str[:-1])
49
 
50
  except (json.JSONDecodeError, IndexError, KeyError, TypeError):
51
- pass # If JSON parsing fails, proceed to regex and attribute checks
52
 
53
  # 2. Common regex patterns for 'retry-after'
54
  patterns = [
55
- r'retry after:?\s*(\d+)',
56
- r'retry_after:?\s*(\d+)',
57
- r'retry in\s*(\d+)\s*seconds',
58
- r'wait for\s*(\d+)\s*seconds',
59
  r'"retryDelay":\s*"(\d+)s"',
60
  ]
61
-
62
  for pattern in patterns:
63
  match = re.search(pattern, error_str)
64
  if match:
@@ -66,104 +90,128 @@ def get_retry_after(error: Exception) -> Optional[int]:
66
  return int(match.group(1))
67
  except (ValueError, IndexError):
68
  continue
69
-
70
  # 3. Handle cases where the error object itself has the attribute
71
- if hasattr(error, 'retry_after'):
72
- value = getattr(error, 'retry_after')
73
  if isinstance(value, int):
74
  return value
75
  if isinstance(value, str) and value.isdigit():
76
  return int(value)
77
-
78
  return None
79
 
 
80
  def classify_error(e: Exception) -> ClassifiedError:
81
  """
82
  Classifies an exception into a structured ClassifiedError object.
83
  Now handles both litellm and httpx exceptions.
84
  """
85
- status_code = getattr(e, 'status_code', None)
86
- if isinstance(e, httpx.HTTPStatusError): # [NEW] Handle httpx errors first
87
  status_code = e.response.status_code
88
  if status_code == 401:
89
- return ClassifiedError(error_type='authentication', original_exception=e, status_code=status_code)
 
 
 
 
90
  if status_code == 429:
91
  retry_after = get_retry_after(e)
92
- return ClassifiedError(error_type='rate_limit', original_exception=e, status_code=status_code, retry_after=retry_after)
 
 
 
 
 
93
  if 400 <= status_code < 500:
94
- return ClassifiedError(error_type='invalid_request', original_exception=e, status_code=status_code)
 
 
 
 
95
  if 500 <= status_code:
96
- return ClassifiedError(error_type='server_error', original_exception=e, status_code=status_code)
97
-
98
- if isinstance(e, (httpx.TimeoutException, httpx.ConnectError, httpx.NetworkError)): # [NEW]
99
- return ClassifiedError(error_type='api_connection', original_exception=e, status_code=status_code)
 
 
 
 
 
 
100
 
101
  if isinstance(e, PreRequestCallbackError):
102
  return ClassifiedError(
103
- error_type='pre_request_callback_error',
104
  original_exception=e,
105
- status_code=400 # Treat as a bad request
106
  )
107
 
108
  if isinstance(e, RateLimitError):
109
  retry_after = get_retry_after(e)
110
  return ClassifiedError(
111
- error_type='rate_limit',
112
  original_exception=e,
113
  status_code=status_code or 429,
114
- retry_after=retry_after
115
  )
116
-
117
  if isinstance(e, (AuthenticationError,)):
118
  return ClassifiedError(
119
- error_type='authentication',
120
  original_exception=e,
121
- status_code=status_code or 401
122
  )
123
-
124
  if isinstance(e, (InvalidRequestError, BadRequestError)):
125
  return ClassifiedError(
126
- error_type='invalid_request',
127
  original_exception=e,
128
- status_code=status_code or 400
129
  )
130
-
131
  if isinstance(e, ContextWindowExceededError):
132
  return ClassifiedError(
133
- error_type='context_window_exceeded',
134
  original_exception=e,
135
- status_code=status_code or 400
136
  )
137
 
138
  if isinstance(e, (APIConnectionError, Timeout)):
139
  return ClassifiedError(
140
- error_type='api_connection',
141
  original_exception=e,
142
- status_code=status_code or 503 # Treat like a server error
143
  )
144
 
145
  if isinstance(e, (ServiceUnavailableError, InternalServerError, OpenAIError)):
146
  # These are often temporary server-side issues
147
  return ClassifiedError(
148
- error_type='server_error',
149
  original_exception=e,
150
- status_code=status_code or 503
151
  )
152
 
153
  # Fallback for any other unclassified errors
154
  return ClassifiedError(
155
- error_type='unknown',
156
- original_exception=e,
157
- status_code=status_code
158
  )
159
 
 
160
  def is_rate_limit_error(e: Exception) -> bool:
161
  """Checks if the exception is a rate limit error."""
162
  return isinstance(e, RateLimitError)
163
 
 
164
  def is_server_error(e: Exception) -> bool:
165
  """Checks if the exception is a temporary server-side error."""
166
- return isinstance(e, (ServiceUnavailableError, APIConnectionError, InternalServerError, OpenAIError))
 
 
 
 
167
 
168
  def is_unrecoverable_error(e: Exception) -> bool:
169
  """
@@ -172,17 +220,58 @@ def is_unrecoverable_error(e: Exception) -> bool:
172
  """
173
  return isinstance(e, (InvalidRequestError, AuthenticationError, BadRequestError))
174
 
 
175
  class AllProviders:
176
  """
177
  A class to handle provider-specific settings, such as custom API bases.
 
178
  """
 
179
  def __init__(self):
180
  self.providers = {
181
  "chutes": {
182
  "api_base": "https://llm.chutes.ai/v1",
183
- "model_prefix": "openai/"
184
  }
185
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
  def get_provider_kwargs(self, **kwargs) -> Dict[str, Any]:
188
  """
@@ -194,17 +283,22 @@ class AllProviders:
194
 
195
  provider = self._get_provider_from_model(model)
196
  provider_settings = self.providers.get(provider, {})
197
-
198
  if "api_base" in provider_settings:
199
  kwargs["api_base"] = provider_settings["api_base"]
200
-
201
- if "model_prefix" in provider_settings:
202
- kwargs["model"] = f"{provider_settings['model_prefix']}{model.split('/', 1)[1]}"
203
-
 
 
 
 
 
204
  return kwargs
205
 
206
  def _get_provider_from_model(self, model: str) -> str:
207
  """
208
  Determines the provider from the model name.
209
  """
210
- return model.split('/')[0]
 
3
  from typing import Optional, Dict, Any
4
  import httpx
5
 
6
+ from litellm.exceptions import (
7
+ APIConnectionError,
8
+ RateLimitError,
9
+ ServiceUnavailableError,
10
+ AuthenticationError,
11
+ InvalidRequestError,
12
+ BadRequestError,
13
+ OpenAIError,
14
+ InternalServerError,
15
+ Timeout,
16
+ ContextWindowExceededError,
17
+ )
18
+
19
 
20
  class NoAvailableKeysError(Exception):
21
  """Raised when no API keys are available for a request after waiting."""
22
+
23
  pass
24
 
25
+
26
  class PreRequestCallbackError(Exception):
27
  """Raised when a pre-request callback fails."""
28
+
29
  pass
30
 
31
+
32
class ClassifiedError:
    """
    Normalized result of error classification: a category string plus the
    originating exception and any HTTP / retry metadata that was extracted.
    """

    def __init__(
        self,
        error_type: str,
        original_exception: Exception,
        status_code: Optional[int] = None,
        retry_after: Optional[int] = None,
    ):
        # Normalized category, e.g. "rate_limit", "authentication", "unknown".
        self.error_type = error_type
        # The exception instance that was classified.
        self.original_exception = original_exception
        # HTTP status code, when one could be determined.
        self.status_code = status_code
        # Seconds to wait before retrying, when the provider supplied a hint.
        self.retry_after = retry_after

    def __str__(self):
        return f"ClassifiedError(type={self.error_type}, status={self.status_code}, retry_after={self.retry_after}, original_exc={self.original_exception})"
49
 
50
+
51
  def get_retry_after(error: Exception) -> Optional[int]:
52
  """
53
  Extracts the 'retry-after' duration in seconds from an exception message.
 
58
  # 1. Try to parse JSON from the error string to find 'retryDelay'
59
  try:
60
  # It's common for the actual JSON to be embedded in the string representation
61
+ json_match = re.search(r"(\{.*\})", error_str)
62
  if json_match:
63
  error_json = json.loads(json_match.group(1))
64
+ retry_info = error_json.get("error", {}).get("details", [{}])[0]
65
+ if retry_info.get("@type") == "type.googleapis.com/google.rpc.RetryInfo":
66
+ delay_str = retry_info.get("retryDelay", {}).get("seconds")
67
  if delay_str:
68
  return int(delay_str)
69
  # Fallback for the other format
70
+ delay_str = retry_info.get("retryDelay")
71
+ if isinstance(delay_str, str) and delay_str.endswith("s"):
72
  return int(delay_str[:-1])
73
 
74
  except (json.JSONDecodeError, IndexError, KeyError, TypeError):
75
+ pass # If JSON parsing fails, proceed to regex and attribute checks
76
 
77
  # 2. Common regex patterns for 'retry-after'
78
  patterns = [
79
+ r"retry after:?\s*(\d+)",
80
+ r"retry_after:?\s*(\d+)",
81
+ r"retry in\s*(\d+)\s*seconds",
82
+ r"wait for\s*(\d+)\s*seconds",
83
  r'"retryDelay":\s*"(\d+)s"',
84
  ]
85
+
86
  for pattern in patterns:
87
  match = re.search(pattern, error_str)
88
  if match:
 
90
  return int(match.group(1))
91
  except (ValueError, IndexError):
92
  continue
93
+
94
  # 3. Handle cases where the error object itself has the attribute
95
+ if hasattr(error, "retry_after"):
96
+ value = getattr(error, "retry_after")
97
  if isinstance(value, int):
98
  return value
99
  if isinstance(value, str) and value.isdigit():
100
  return int(value)
101
+
102
  return None
103
 
104
+
105
def classify_error(e: Exception) -> ClassifiedError:
    """
    Classify an exception into a structured ClassifiedError.

    Handles both httpx transport/status errors and litellm exception types,
    mapping each onto a normalized error_type plus a status code and, for
    rate limits, an optional retry-after hint.
    """
    status_code = getattr(e, "status_code", None)

    # httpx status errors are checked first: the response carries the
    # authoritative HTTP status code.
    if isinstance(e, httpx.HTTPStatusError):
        status_code = e.response.status_code
        if status_code == 401:
            return ClassifiedError(
                error_type="authentication",
                original_exception=e,
                status_code=status_code,
            )
        if status_code == 429:
            return ClassifiedError(
                error_type="rate_limit",
                original_exception=e,
                status_code=status_code,
                retry_after=get_retry_after(e),
            )
        if 400 <= status_code < 500:
            return ClassifiedError(
                error_type="invalid_request",
                original_exception=e,
                status_code=status_code,
            )
        if 500 <= status_code:
            return ClassifiedError(
                error_type="server_error",
                original_exception=e,
                status_code=status_code,
            )

    if isinstance(e, (httpx.TimeoutException, httpx.ConnectError, httpx.NetworkError)):
        return ClassifiedError(
            error_type="api_connection",
            original_exception=e,
            status_code=status_code,
        )

    if isinstance(e, PreRequestCallbackError):
        return ClassifiedError(
            error_type="pre_request_callback_error",
            original_exception=e,
            status_code=400,  # Treat as a bad request
        )

    if isinstance(e, RateLimitError):
        return ClassifiedError(
            error_type="rate_limit",
            original_exception=e,
            status_code=status_code or 429,
            retry_after=get_retry_after(e),
        )

    if isinstance(e, AuthenticationError):
        return ClassifiedError(
            error_type="authentication",
            original_exception=e,
            status_code=status_code or 401,
        )

    # [FIX] litellm's ContextWindowExceededError subclasses BadRequestError,
    # so it must be tested BEFORE the generic bad-request branch; previously
    # that branch shadowed this one and "context_window_exceeded" was
    # unreachable.
    if isinstance(e, ContextWindowExceededError):
        return ClassifiedError(
            error_type="context_window_exceeded",
            original_exception=e,
            status_code=status_code or 400,
        )

    if isinstance(e, (InvalidRequestError, BadRequestError)):
        return ClassifiedError(
            error_type="invalid_request",
            original_exception=e,
            status_code=status_code or 400,
        )

    if isinstance(e, (APIConnectionError, Timeout)):
        return ClassifiedError(
            error_type="api_connection",
            original_exception=e,
            status_code=status_code or 503,  # Treat like a server error
        )

    if isinstance(e, (ServiceUnavailableError, InternalServerError, OpenAIError)):
        # These are often temporary server-side issues.
        return ClassifiedError(
            error_type="server_error",
            original_exception=e,
            status_code=status_code or 503,
        )

    # Fallback for any other unclassified errors.
    return ClassifiedError(
        error_type="unknown", original_exception=e, status_code=status_code
    )
201
 
202
+
203
def is_rate_limit_error(e: Exception) -> bool:
    """Return True when *e* is a litellm RateLimitError."""
    return isinstance(e, RateLimitError)
206
 
207
+
208
def is_server_error(e: Exception) -> bool:
    """Return True for errors that look like transient server-side faults."""
    transient_types = (
        ServiceUnavailableError,
        APIConnectionError,
        InternalServerError,
        OpenAIError,
    )
    return isinstance(e, transient_types)
214
+
215
 
216
def is_unrecoverable_error(e: Exception) -> bool:
    """
    Return True for errors that key rotation cannot fix (malformed request
    or bad credentials), so retrying with another key is pointless.
    """
    fatal_types = (InvalidRequestError, AuthenticationError, BadRequestError)
    return isinstance(e, fatal_types)
222
 
223
+
224
  class AllProviders:
225
  """
226
  A class to handle provider-specific settings, such as custom API bases.
227
+ Supports custom OpenAI-compatible providers configured via environment variables.
228
  """
229
+
230
  def __init__(self):
231
  self.providers = {
232
  "chutes": {
233
  "api_base": "https://llm.chutes.ai/v1",
234
+ "model_prefix": "openai/",
235
  }
236
  }
237
+ # Load custom OpenAI-compatible providers from environment
238
+ self._load_custom_providers()
239
+
240
+ def _load_custom_providers(self):
241
+ """
242
+ Loads custom OpenAI-compatible providers from environment variables.
243
+ Looks for environment variables in the format: PROVIDER_API_BASE
244
+ where PROVIDER is the name of the custom provider.
245
+ """
246
+ import os
247
+
248
+ # Get all environment variables that end with _API_BASE
249
+ for env_var in os.environ:
250
+ if env_var.endswith("_API_BASE"):
251
+ provider_name = env_var.split("_API_BASE")[
252
+ 0
253
+ ].lower() # Remove '_API_BASE' suffix and lowercase
254
+
255
+ # Skip known providers that are already handled
256
+ if provider_name in [
257
+ "openai",
258
+ "anthropic",
259
+ "google",
260
+ "gemini",
261
+ "nvidia",
262
+ "mistral",
263
+ "cohere",
264
+ "groq",
265
+ "openrouter",
266
+ ]:
267
+ continue
268
+
269
+ api_base = os.getenv(env_var)
270
+ if api_base:
271
+ self.providers[provider_name] = {
272
+ "api_base": api_base.rstrip("/") if api_base else None,
273
+ "model_prefix": None, # No prefix for custom providers
274
+ }
275
 
276
  def get_provider_kwargs(self, **kwargs) -> Dict[str, Any]:
277
  """
 
283
 
284
  provider = self._get_provider_from_model(model)
285
  provider_settings = self.providers.get(provider, {})
286
+
287
  if "api_base" in provider_settings:
288
  kwargs["api_base"] = provider_settings["api_base"]
289
+
290
+ if (
291
+ "model_prefix" in provider_settings
292
+ and provider_settings["model_prefix"] is not None
293
+ ):
294
+ kwargs["model"] = (
295
+ f"{provider_settings['model_prefix']}{model.split('/', 1)[1]}"
296
+ )
297
+
298
  return kwargs
299
 
300
  def _get_provider_from_model(self, model: str) -> str:
301
  """
302
  Determines the provider from the model name.
303
  """
304
+ return model.split("/")[0]
src/rotator_library/providers/openai_compatible_provider.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import httpx
3
+ import logging
4
+ from typing import List, Dict, Any, Optional
5
+ from .provider_interface import ProviderInterface
6
+
7
# Shared library logger. A NullHandler guarantees that "no handler could be
# found" warnings never surface when the host app leaves logging unconfigured.
# NOTE(review): with propagate=False and only a NullHandler attached, records
# never reach the root logger — this module is effectively silent unless the
# host attaches handlers to "rotator_library" directly; confirm this matches
# the intent in client.py.
lib_logger = logging.getLogger("rotator_library")
lib_logger.propagate = False
if not lib_logger.handlers:
    lib_logger.addHandler(logging.NullHandler())
11
+
12
+
13
class OpenAICompatibleProvider(ProviderInterface):
    """
    Generic provider implementation for any OpenAI-compatible API.

    The endpoint is configured entirely through the environment variable
    ``<PROVIDER>_API_BASE``, so new standards-compliant endpoints can be used
    without code changes.
    """

    def __init__(self, provider_name: str):
        """
        Args:
            provider_name: Provider key; also used to resolve the
                ``<PROVIDER>_API_BASE`` environment variable.

        Raises:
            ValueError: If the API base environment variable is not set.
        """
        self.provider_name = provider_name
        # Get API base URL from environment.
        self.api_base = os.getenv(f"{provider_name.upper()}_API_BASE")
        if not self.api_base:
            raise ValueError(
                f"Environment variable {provider_name.upper()}_API_BASE is required for OpenAI-compatible provider"
            )

    async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]:
        """
        Fetch available models from ``GET <api_base>/models``.

        Returns model ids prefixed with the provider name
        (``<provider>/<id>``). Returns an empty list on any failure so one
        bad endpoint cannot break model aggregation.
        """
        try:
            models_url = f"{self.api_base.rstrip('/')}/models"
            response = await client.get(
                models_url, headers={"Authorization": f"Bearer {api_key}"}
            )
            response.raise_for_status()
            # [FIX] Skip malformed entries that lack an "id" field instead of
            # letting a KeyError abort the whole listing.
            return [
                f"{self.provider_name}/{model['id']}"
                for model in response.json().get("data", [])
                if isinstance(model, dict) and "id" in model
            ]
        except httpx.HTTPStatusError as e:
            # [FIX] raise_for_status() raises HTTPStatusError, which is NOT a
            # subclass of httpx.RequestError — previously these fell through
            # to the generic handler and were mislabeled "Unexpected error".
            lib_logger.error(
                f"Failed to fetch models for {self.provider_name}: "
                f"HTTP {e.response.status_code}"
            )
            return []
        except httpx.RequestError as e:
            lib_logger.error(f"Failed to fetch models for {self.provider_name}: {e}")
            return []
        except Exception as e:
            lib_logger.error(
                f"Unexpected error fetching models for {self.provider_name}: {e}"
            )
            return []

    def has_custom_logic(self) -> bool:
        """
        Returns False since we want to use the standard litellm flow
        with just custom API base configuration.
        """
        return False

    async def get_auth_header(self, credential_identifier: str) -> Dict[str, str]:
        """
        Returns the standard Bearer token header for API key authentication.
        """
        return {"Authorization": f"Bearer {credential_identifier}"}