mirrobot-agent[bot] committed on
Commit
a1cc875
·
1 Parent(s): 7cb148b

fix: improve error handling implementation based on code review

Browse files

- Fix credential counting to track unique credentials (RequestErrorAccumulator)
- Move import os to module level in mask_credential function
- Fix status code check to use explicit 'is not None' comparison
- Improve context window error detection with more specific patterns
- Correct comment about server error classification
- Remove redundant '1 *' in exponential backoff calculations
- Add warning log for unreachable None return path
- Remove redundant error_accumulator.model/provider assignments
- Remove access to private _content attribute in failure_logger
- Add circular reference detection in error chain loop
- Reorder error recording to occur after should_rotate_on_error check

These changes address issues identified in both mirrobot-agent and
GitHub Copilot code reviews.

src/rotator_library/client.py CHANGED
@@ -71,7 +71,7 @@ class RotatingClient:
71
  ):
72
  """
73
  Initialize the RotatingClient with intelligent credential rotation.
74
-
75
  Args:
76
  api_keys: Dictionary mapping provider names to lists of API keys
77
  oauth_credentials: Dictionary mapping provider names to OAuth credential paths
@@ -140,8 +140,7 @@ class RotatingClient:
140
  self.global_timeout = global_timeout
141
  self.abort_on_callback_error = abort_on_callback_error
142
  self.usage_manager = UsageManager(
143
- file_path=usage_file_path,
144
- rotation_tolerance=rotation_tolerance
145
  )
146
  self._model_list_cache = {}
147
  self._provider_plugins = PROVIDER_PLUGINS
@@ -160,7 +159,9 @@ class RotatingClient:
160
  # Validate all values are >= 1
161
  for provider, max_val in self.max_concurrent_requests_per_key.items():
162
  if max_val < 1:
163
- lib_logger.warning(f"Invalid max_concurrent for '{provider}': {max_val}. Setting to 1.")
 
 
164
  self.max_concurrent_requests_per_key[provider] = 1
165
 
166
  def _is_model_ignored(self, provider: str, model_id: str) -> bool:
@@ -368,7 +369,9 @@ class RotatingClient:
368
 
369
  return kwargs
370
 
371
- def _apply_default_safety_settings(self, litellm_kwargs: Dict[str, Any], provider: str):
 
 
372
  """
373
  Ensure default Gemini safety settings are present when calling the Gemini provider.
374
  This will not override any explicit settings provided by the request. It accepts
@@ -397,22 +400,33 @@ class RotatingClient:
397
  ]
398
 
399
  # If generic form is present, ensure missing generic keys are filled in
400
- if "safety_settings" in litellm_kwargs and isinstance(litellm_kwargs["safety_settings"], dict):
 
 
401
  for k, v in default_generic.items():
402
  if k not in litellm_kwargs["safety_settings"]:
403
  litellm_kwargs["safety_settings"][k] = v
404
  return
405
 
406
  # If Gemini form is present, ensure missing gemini categories are appended
407
- if "safetySettings" in litellm_kwargs and isinstance(litellm_kwargs["safetySettings"], list):
408
- present = {item.get("category") for item in litellm_kwargs["safetySettings"] if isinstance(item, dict)}
 
 
 
 
 
 
409
  for d in default_gemini:
410
  if d["category"] not in present:
411
  litellm_kwargs["safetySettings"].append(d)
412
  return
413
 
414
  # Neither present: set generic defaults so provider conversion will translate them
415
- if "safety_settings" not in litellm_kwargs and "safetySettings" not in litellm_kwargs:
 
 
 
416
  litellm_kwargs["safety_settings"] = default_generic.copy()
417
 
418
  def get_oauth_credentials(self) -> Dict[str, List[str]]:
@@ -430,10 +444,10 @@ class RotatingClient:
430
  """
431
  Lazily initializes and returns a provider instance.
432
  Only initializes providers that have configured credentials.
433
-
434
  Args:
435
  provider_name: The name of the provider to get an instance for.
436
-
437
  Returns:
438
  Provider instance if credentials exist, None otherwise.
439
  """
@@ -443,7 +457,7 @@ class RotatingClient:
443
  f"Skipping provider '{provider_name}' initialization: no credentials configured"
444
  )
445
  return None
446
-
447
  if provider_name not in self._provider_instances:
448
  if provider_name in self._provider_plugins:
449
  self._provider_instances[provider_name] = self._provider_plugins[
@@ -465,46 +479,47 @@ class RotatingClient:
465
  def _resolve_model_id(self, model: str, provider: str) -> str:
466
  """
467
  Resolves the actual model ID to send to the provider.
468
-
469
  For custom models with name/ID mappings, returns the ID.
470
  Otherwise, returns the model name unchanged.
471
-
472
  Args:
473
  model: Full model string with provider (e.g., "iflow/DS-v3.2")
474
  provider: Provider name (e.g., "iflow")
475
-
476
  Returns:
477
  Full model string with ID (e.g., "iflow/deepseek-v3.2")
478
  """
479
  # Extract model name from "provider/model_name" format
480
- model_name = model.split('/')[-1] if '/' in model else model
481
-
482
  # Try to get provider instance to check for model definitions
483
  provider_plugin = self._get_provider_instance(provider)
484
-
485
  # Check if provider has model definitions
486
- if provider_plugin and hasattr(provider_plugin, 'model_definitions'):
487
- model_id = provider_plugin.model_definitions.get_model_id(provider, model_name)
 
 
488
  if model_id and model_id != model_name:
489
  # Return with provider prefix
490
  return f"{provider}/{model_id}"
491
-
492
  # Fallback: use client's own model definitions
493
  model_id = self.model_definitions.get_model_id(provider, model_name)
494
  if model_id and model_id != model_name:
495
  return f"{provider}/{model_id}"
496
-
497
  # No conversion needed, return original
498
  return model
499
 
500
-
501
  async def _safe_streaming_wrapper(
502
  self, stream: Any, key: str, model: str, request: Optional[Any] = None
503
  ) -> AsyncGenerator[Any, None]:
504
  """
505
  A hybrid wrapper for streaming that buffers fragmented JSON, handles client disconnections gracefully,
506
  and distinguishes between content and streamed errors.
507
-
508
  FINISH_REASON HANDLING:
509
  Providers just translate chunks - this wrapper handles ALL finish_reason logic:
510
  1. Strip finish_reason from intermediate chunks (litellm defaults to "stop")
@@ -541,7 +556,7 @@ class RotatingClient:
541
  chunk_dict = chunk.model_dump()
542
  else:
543
  chunk_dict = chunk
544
-
545
  # === FINISH_REASON LOGIC ===
546
  # Providers send raw chunks without finish_reason logic.
547
  # This wrapper determines finish_reason based on accumulated state.
@@ -549,19 +564,19 @@ class RotatingClient:
549
  choice = chunk_dict["choices"][0]
550
  delta = choice.get("delta", {})
551
  usage = chunk_dict.get("usage", {})
552
-
553
  # Track tool_calls across ALL chunks - if we ever see one, finish_reason must be tool_calls
554
  if delta.get("tool_calls"):
555
  has_tool_calls = True
556
  accumulated_finish_reason = "tool_calls"
557
-
558
  # Detect final chunk: has usage with completion_tokens > 0
559
  has_completion_tokens = (
560
- usage and
561
- isinstance(usage, dict) and
562
- usage.get("completion_tokens", 0) > 0
563
  )
564
-
565
  if has_completion_tokens:
566
  # FINAL CHUNK: Determine correct finish_reason
567
  if has_tool_calls:
@@ -577,7 +592,7 @@ class RotatingClient:
577
  # INTERMEDIATE CHUNK: Never emit finish_reason
578
  # (litellm.ModelResponse defaults to "stop" which is wrong)
579
  choice["finish_reason"] = None
580
-
581
  yield f"data: {json.dumps(chunk_dict)}\n\n"
582
 
583
  if hasattr(chunk, "usage") and chunk.usage:
@@ -726,12 +741,13 @@ class RotatingClient:
726
  # multiple keys have the same usage stats.
727
  credentials_for_provider = list(self.all_credentials[provider])
728
  random.shuffle(credentials_for_provider)
729
-
730
  # Filter out credentials that are unavailable (queued for re-auth)
731
  provider_plugin = self._get_provider_instance(provider)
732
- if provider_plugin and hasattr(provider_plugin, 'is_credential_available'):
733
  available_creds = [
734
- cred for cred in credentials_for_provider
 
735
  if provider_plugin.is_credential_available(cred)
736
  ]
737
  if available_creds:
@@ -744,7 +760,7 @@ class RotatingClient:
744
  kwargs = self._convert_model_params(**kwargs)
745
 
746
  # The main rotation loop. It continues as long as there are untried credentials and the global deadline has not been exceeded.
747
-
748
  # Resolve model ID early, before any credential operations
749
  # This ensures consistent model ID usage for acquisition, release, and tracking
750
  resolved_model = self._resolve_model_id(model, provider)
@@ -752,10 +768,10 @@ class RotatingClient:
752
  lib_logger.info(f"Resolved model '{model}' to '{resolved_model}'")
753
  model = resolved_model
754
  kwargs["model"] = model # Ensure kwargs has the resolved model for litellm
755
-
756
  # [NEW] Filter by model tier requirement and build priority map
757
  credential_priorities = None
758
- if provider_plugin and hasattr(provider_plugin, 'get_model_tier_requirement'):
759
  required_tier = provider_plugin.get_model_tier_requirement(model)
760
  if required_tier is not None:
761
  # Filter OUT only credentials we KNOW are too low priority
@@ -763,9 +779,9 @@ class RotatingClient:
763
  incompatible_creds = []
764
  compatible_creds = []
765
  unknown_creds = []
766
-
767
  for cred in credentials_for_provider:
768
- if hasattr(provider_plugin, 'get_credential_priority'):
769
  priority = provider_plugin.get_credential_priority(cred)
770
  if priority is None:
771
  # Unknown priority - keep it, will be discovered on first use
@@ -779,7 +795,7 @@ class RotatingClient:
779
  else:
780
  # Provider doesn't support priorities - keep all
781
  unknown_creds.append(cred)
782
-
783
  # If we have any known-compatible or unknown credentials, use them
784
  tier_compatible_creds = compatible_creds + unknown_creds
785
  if tier_compatible_creds:
@@ -806,18 +822,18 @@ class RotatingClient:
806
  f"but all {len(incompatible_creds)} known credentials have priority > {required_tier}. "
807
  f"Request will likely fail."
808
  )
809
-
810
  # Build priority map for usage_manager
811
- if provider_plugin and hasattr(provider_plugin, 'get_credential_priority'):
812
  credential_priorities = {}
813
  for cred in credentials_for_provider:
814
  priority = provider_plugin.get_credential_priority(cred)
815
  if priority is not None:
816
  credential_priorities[cred] = priority
817
-
818
  if credential_priorities:
819
  lib_logger.debug(
820
- f"Credential priorities for {provider}: {', '.join(f'P{p}={len([c for c in credentials_for_provider if credential_priorities.get(c)==p])}' for p in sorted(set(credential_priorities.values())))}"
821
  )
822
 
823
  # Initialize error accumulator for tracking errors across credential rotation
@@ -861,9 +877,11 @@ class RotatingClient:
861
  )
862
  max_concurrent = self.max_concurrent_requests_per_key.get(provider, 1)
863
  current_cred = await self.usage_manager.acquire_key(
864
- available_keys=creds_to_try, model=model, deadline=deadline,
 
 
865
  max_concurrent=max_concurrent,
866
- credential_priorities=credential_priorities
867
  )
868
  key_acquired = True
869
  tried_creds.add(current_cred)
@@ -946,10 +964,14 @@ class RotatingClient:
946
  if provider_instance:
947
  # Ensure default Gemini safety settings are present (without overriding request)
948
  try:
949
- self._apply_default_safety_settings(litellm_kwargs, provider)
 
 
950
  except Exception:
951
  # If anything goes wrong here, avoid breaking the request flow.
952
- lib_logger.debug("Could not apply default safety settings; continuing.")
 
 
953
 
954
  if "safety_settings" in litellm_kwargs:
955
  converted_settings = (
@@ -1032,9 +1054,11 @@ class RotatingClient:
1032
 
1033
  # Extract a clean error message for the user-facing log
1034
  error_message = str(e).split("\n")[0]
1035
-
1036
  # Record in accumulator for client reporting
1037
- error_accumulator.record_error(current_cred, classified_error, error_message)
 
 
1038
 
1039
  lib_logger.info(
1040
  f"Key {mask_credential(current_cred)} hit rate limit for {model}. Rotating key."
@@ -1068,16 +1092,20 @@ class RotatingClient:
1068
  )
1069
  classified_error = classify_error(e)
1070
  error_message = str(e).split("\n")[0]
1071
-
1072
  # Provider-level error: don't increment consecutive failures
1073
  await self.usage_manager.record_failure(
1074
- current_cred, model, classified_error,
1075
- increment_consecutive_failures=False
 
 
1076
  )
1077
 
1078
  if attempt >= self.max_retries - 1:
1079
  # Record in accumulator only on final failure for this key
1080
- error_accumulator.record_error(current_cred, classified_error, error_message)
 
 
1081
  lib_logger.warning(
1082
  f"Key {mask_credential(current_cred)} failed after max retries due to server error. Rotating."
1083
  )
@@ -1085,13 +1113,15 @@ class RotatingClient:
1085
 
1086
  # For temporary errors, wait before retrying with the same key.
1087
  wait_time = classified_error.retry_after or (
1088
- 1 * (2**attempt)
1089
  ) + random.uniform(0, 1)
1090
  remaining_budget = deadline - time.time()
1091
 
1092
  # If the required wait time exceeds the budget, don't wait; rotate to the next key immediately.
1093
  if wait_time > remaining_budget:
1094
- error_accumulator.record_error(current_cred, classified_error, error_message)
 
 
1095
  lib_logger.warning(
1096
  f"Retry wait ({wait_time:.2f}s) exceeds budget ({remaining_budget:.2f}s). Rotating key."
1097
  )
@@ -1115,34 +1145,44 @@ class RotatingClient:
1115
  if request
1116
  else {},
1117
  )
1118
-
1119
  classified_error = classify_error(e)
1120
  error_message = str(e).split("\n")[0]
1121
-
1122
- # Record in accumulator for client reporting
1123
- error_accumulator.record_error(current_cred, classified_error, error_message)
1124
-
1125
  lib_logger.warning(
1126
  f"Key {mask_credential(current_cred)} HTTP {e.response.status_code} ({classified_error.error_type})."
1127
  )
1128
-
1129
  # Check if this error should trigger rotation
1130
  if not should_rotate_on_error(classified_error):
1131
  lib_logger.error(
1132
  f"Non-recoverable error ({classified_error.error_type}). Failing request."
1133
  )
1134
  raise last_exception
1135
-
 
 
 
 
 
1136
  # Handle rate limits with cooldown
1137
- if classified_error.error_type in ["rate_limit", "quota_exceeded"]:
 
 
 
1138
  cooldown_duration = classified_error.retry_after or 60
1139
  await self.cooldown_manager.start_cooldown(
1140
  provider, cooldown_duration
1141
  )
1142
-
1143
  # Check if we should retry same key (server errors with retries left)
1144
- if should_retry_same_key(classified_error) and attempt < self.max_retries - 1:
1145
- wait_time = classified_error.retry_after or (1 * (2**attempt)) + random.uniform(0, 1)
 
 
 
 
 
1146
  remaining_budget = deadline - time.time()
1147
  if wait_time <= remaining_budget:
1148
  lib_logger.warning(
@@ -1150,12 +1190,14 @@ class RotatingClient:
1150
  )
1151
  await asyncio.sleep(wait_time)
1152
  continue
1153
-
1154
  # Record failure and rotate to next key
1155
  await self.usage_manager.record_failure(
1156
  current_cred, model, classified_error
1157
  )
1158
- lib_logger.info(f"Rotating to next key after {classified_error.error_type} error.")
 
 
1159
  break
1160
 
1161
  except Exception as e:
@@ -1178,16 +1220,17 @@ class RotatingClient:
1178
 
1179
  classified_error = classify_error(e)
1180
  error_message = str(e).split("\n")[0]
1181
-
1182
- # Record in accumulator for client reporting
1183
- error_accumulator.record_error(current_cred, classified_error, error_message)
1184
-
1185
  lib_logger.warning(
1186
  f"Key {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code})."
1187
  )
1188
-
1189
  # Handle rate limits with cooldown
1190
- if classified_error.status_code == 429 or classified_error.error_type in ["rate_limit", "quota_exceeded"]:
 
 
 
 
1191
  cooldown_duration = classified_error.retry_after or 60
1192
  await self.cooldown_manager.start_cooldown(
1193
  provider, cooldown_duration
@@ -1200,6 +1243,11 @@ class RotatingClient:
1200
  )
1201
  raise last_exception
1202
 
 
 
 
 
 
1203
  await self.usage_manager.record_failure(
1204
  current_cred, model, classified_error
1205
  )
@@ -1211,15 +1259,19 @@ class RotatingClient:
1211
  # Check if we exhausted all credentials or timed out
1212
  if time.time() >= deadline:
1213
  error_accumulator.timeout_occurred = True
1214
-
1215
  if error_accumulator.has_errors():
1216
  # Log concise summary for server logs
1217
  lib_logger.error(error_accumulator.build_log_message())
1218
-
1219
  # Return the structured error response for the client
1220
  return error_accumulator.build_client_error_response()
1221
 
1222
  # Return None to indicate failure without error details (shouldn't normally happen)
 
 
 
 
1223
  return None
1224
 
1225
  async def _streaming_acompletion_with_retry(
@@ -1235,12 +1287,13 @@ class RotatingClient:
1235
  # Create a mutable copy of the keys and shuffle it.
1236
  credentials_for_provider = list(self.all_credentials[provider])
1237
  random.shuffle(credentials_for_provider)
1238
-
1239
  # Filter out credentials that are unavailable (queued for re-auth)
1240
  provider_plugin = self._get_provider_instance(provider)
1241
- if provider_plugin and hasattr(provider_plugin, 'is_credential_available'):
1242
  available_creds = [
1243
- cred for cred in credentials_for_provider
 
1244
  if provider_plugin.is_credential_available(cred)
1245
  ]
1246
  if available_creds:
@@ -1262,10 +1315,10 @@ class RotatingClient:
1262
  lib_logger.info(f"Resolved model '{model}' to '{resolved_model}'")
1263
  model = resolved_model
1264
  kwargs["model"] = model # Ensure kwargs has the resolved model for litellm
1265
-
1266
  # [NEW] Filter by model tier requirement and build priority map
1267
  credential_priorities = None
1268
- if provider_plugin and hasattr(provider_plugin, 'get_model_tier_requirement'):
1269
  required_tier = provider_plugin.get_model_tier_requirement(model)
1270
  if required_tier is not None:
1271
  # Filter OUT only credentials we KNOW are too low priority
@@ -1273,9 +1326,9 @@ class RotatingClient:
1273
  incompatible_creds = []
1274
  compatible_creds = []
1275
  unknown_creds = []
1276
-
1277
  for cred in credentials_for_provider:
1278
- if hasattr(provider_plugin, 'get_credential_priority'):
1279
  priority = provider_plugin.get_credential_priority(cred)
1280
  if priority is None:
1281
  # Unknown priority - keep it, will be discovered on first use
@@ -1289,7 +1342,7 @@ class RotatingClient:
1289
  else:
1290
  # Provider doesn't support priorities - keep all
1291
  unknown_creds.append(cred)
1292
-
1293
  # If we have any known-compatible or unknown credentials, use them
1294
  tier_compatible_creds = compatible_creds + unknown_creds
1295
  if tier_compatible_creds:
@@ -1316,18 +1369,18 @@ class RotatingClient:
1316
  f"but all {len(incompatible_creds)} known credentials have priority > {required_tier}. "
1317
  f"Request will likely fail."
1318
  )
1319
-
1320
  # Build priority map for usage_manager
1321
- if provider_plugin and hasattr(provider_plugin, 'get_credential_priority'):
1322
  credential_priorities = {}
1323
  for cred in credentials_for_provider:
1324
  priority = provider_plugin.get_credential_priority(cred)
1325
  if priority is not None:
1326
  credential_priorities[cred] = priority
1327
-
1328
  if credential_priorities:
1329
  lib_logger.debug(
1330
- f"Credential priorities for {provider}: {', '.join(f'P{p}={len([c for c in credentials_for_provider if credential_priorities.get(c)==p])}' for p in sorted(set(credential_priorities.values())))}"
1331
  )
1332
 
1333
  # Initialize error accumulator for tracking errors across credential rotation
@@ -1370,11 +1423,15 @@ class RotatingClient:
1370
  lib_logger.info(
1371
  f"Acquiring credential for model {model}. Tried credentials: {len(tried_creds)}/{len(credentials_for_provider)}"
1372
  )
1373
- max_concurrent = self.max_concurrent_requests_per_key.get(provider, 1)
 
 
1374
  current_cred = await self.usage_manager.acquire_key(
1375
- available_keys=creds_to_try, model=model, deadline=deadline,
 
 
1376
  max_concurrent=max_concurrent,
1377
- credential_priorities=credential_priorities
1378
  )
1379
  key_acquired = True
1380
  tried_creds.add(current_cred)
@@ -1483,7 +1540,7 @@ class RotatingClient:
1483
  original_exc = getattr(e, "data", e)
1484
  classified_error = classify_error(original_exc)
1485
  error_message = str(original_exc).split("\n")[0]
1486
-
1487
  log_failure(
1488
  api_key=current_cred,
1489
  model=model,
@@ -1493,24 +1550,31 @@ class RotatingClient:
1493
  if request
1494
  else {},
1495
  )
1496
-
1497
  # Record in accumulator for client reporting
1498
- error_accumulator.record_error(current_cred, classified_error, error_message)
1499
-
 
 
1500
  # Check if this error should trigger rotation
1501
  if not should_rotate_on_error(classified_error):
1502
  lib_logger.error(
1503
  f"Non-recoverable error ({classified_error.error_type}) during custom stream. Failing."
1504
  )
1505
  raise last_exception
1506
-
1507
  # Handle rate limits with cooldown
1508
- if classified_error.error_type in ["rate_limit", "quota_exceeded"]:
1509
- cooldown_duration = classified_error.retry_after or 60
 
 
 
 
 
1510
  await self.cooldown_manager.start_cooldown(
1511
  provider, cooldown_duration
1512
  )
1513
-
1514
  await self.usage_manager.record_failure(
1515
  current_cred, model, classified_error
1516
  )
@@ -1536,26 +1600,32 @@ class RotatingClient:
1536
  )
1537
  classified_error = classify_error(e)
1538
  error_message = str(e).split("\n")[0]
1539
-
1540
  # Provider-level error: don't increment consecutive failures
1541
  await self.usage_manager.record_failure(
1542
- current_cred, model, classified_error,
1543
- increment_consecutive_failures=False
 
 
1544
  )
1545
 
1546
  if attempt >= self.max_retries - 1:
1547
- error_accumulator.record_error(current_cred, classified_error, error_message)
 
 
1548
  lib_logger.warning(
1549
  f"Cred {mask_credential(current_cred)} failed after max retries. Rotating."
1550
  )
1551
  break
1552
 
1553
  wait_time = classified_error.retry_after or (
1554
- 1 * (2**attempt)
1555
  ) + random.uniform(0, 1)
1556
  remaining_budget = deadline - time.time()
1557
  if wait_time > remaining_budget:
1558
- error_accumulator.record_error(current_cred, classified_error, error_message)
 
 
1559
  lib_logger.warning(
1560
  f"Retry wait ({wait_time:.2f}s) exceeds budget. Rotating."
1561
  )
@@ -1580,21 +1650,23 @@ class RotatingClient:
1580
  )
1581
  classified_error = classify_error(e)
1582
  error_message = str(e).split("\n")[0]
1583
-
1584
  # Record in accumulator
1585
- error_accumulator.record_error(current_cred, classified_error, error_message)
1586
-
 
 
1587
  lib_logger.warning(
1588
  f"Cred {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code})."
1589
  )
1590
-
1591
  # Check if this error should trigger rotation
1592
  if not should_rotate_on_error(classified_error):
1593
  lib_logger.error(
1594
  f"Non-recoverable error ({classified_error.error_type}). Failing."
1595
  )
1596
  raise last_exception
1597
-
1598
  await self.usage_manager.record_failure(
1599
  current_cred, model, classified_error
1600
  )
@@ -1616,9 +1688,13 @@ class RotatingClient:
1616
  if provider_instance:
1617
  # Ensure default Gemini safety settings are present (without overriding request)
1618
  try:
1619
- self._apply_default_safety_settings(litellm_kwargs, provider)
 
 
1620
  except Exception:
1621
- lib_logger.debug("Could not apply default safety settings for streaming path; continuing.")
 
 
1622
 
1623
  if "safety_settings" in litellm_kwargs:
1624
  converted_settings = (
@@ -1699,7 +1775,11 @@ class RotatingClient:
1699
  yield chunk
1700
  return
1701
 
1702
- except (StreamedAPIError, litellm.RateLimitError, httpx.HTTPStatusError) as e:
 
 
 
 
1703
  last_exception = e
1704
 
1705
  # This is the final, robust handler for streamed errors.
@@ -1708,7 +1788,7 @@ class RotatingClient:
1708
  # The actual exception might be wrapped in our StreamedAPIError.
1709
  original_exc = getattr(e, "data", e)
1710
  classified_error = classify_error(original_exc)
1711
-
1712
  # Check if this error should trigger rotation
1713
  if not should_rotate_on_error(classified_error):
1714
  lib_logger.error(
@@ -1745,16 +1825,18 @@ class RotatingClient:
1745
  error_message_text = error_details.get(
1746
  "message", str(original_exc).split("\n")[0]
1747
  )
1748
-
1749
  # Record in accumulator for client reporting
1750
- error_accumulator.record_error(current_cred, classified_error, error_message_text)
 
 
1751
 
1752
  if (
1753
  "quota" in error_message_text.lower()
1754
  or "resource_exhausted" in error_status.lower()
1755
  ):
1756
  consecutive_quota_failures += 1
1757
-
1758
  quota_value = "N/A"
1759
  quota_id = "N/A"
1760
  if "details" in error_details and isinstance(
@@ -1764,10 +1846,15 @@ class RotatingClient:
1764
  if isinstance(detail.get("violations"), list):
1765
  for violation in detail["violations"]:
1766
  if "quotaValue" in violation:
1767
- quota_value = violation["quotaValue"]
 
 
1768
  if "quotaId" in violation:
1769
  quota_id = violation["quotaId"]
1770
- if quota_value != "N/A" and quota_id != "N/A":
 
 
 
1771
  break
1772
 
1773
  await self.usage_manager.record_failure(
@@ -1798,8 +1885,13 @@ class RotatingClient:
1798
  f"Cred {mask_credential(current_cred)} {classified_error.error_type}. Rotating."
1799
  )
1800
 
1801
- if classified_error.error_type in ["rate_limit", "quota_exceeded"]:
1802
- cooldown_duration = classified_error.retry_after or 60
 
 
 
 
 
1803
  await self.cooldown_manager.start_cooldown(
1804
  provider, cooldown_duration
1805
  )
@@ -1827,14 +1919,18 @@ class RotatingClient:
1827
  )
1828
  classified_error = classify_error(e)
1829
  error_message_text = str(e).split("\n")[0]
1830
-
1831
- # Record error in accumulator (server errors are abnormal)
1832
- error_accumulator.record_error(current_cred, classified_error, error_message_text)
1833
-
 
 
1834
  # Provider-level error: don't increment consecutive failures
1835
  await self.usage_manager.record_failure(
1836
- current_cred, model, classified_error,
1837
- increment_consecutive_failures=False
 
 
1838
  )
1839
 
1840
  if attempt >= self.max_retries - 1:
@@ -1845,7 +1941,7 @@ class RotatingClient:
1845
  break
1846
 
1847
  wait_time = classified_error.retry_after or (
1848
- 1 * (2**attempt)
1849
  ) + random.uniform(0, 1)
1850
  remaining_budget = deadline - time.time()
1851
  if wait_time > remaining_budget:
@@ -1874,16 +1970,22 @@ class RotatingClient:
1874
  )
1875
  classified_error = classify_error(e)
1876
  error_message_text = str(e).split("\n")[0]
1877
-
1878
  # Record error in accumulator
1879
- error_accumulator.record_error(current_cred, classified_error, error_message_text)
 
 
1880
 
1881
  lib_logger.warning(
1882
  f"Credential ...{current_cred[-6:]} failed with {classified_error.error_type} (Status: {classified_error.status_code}). Error: {error_message_text}."
1883
  )
1884
 
1885
  # Handle rate limits with cooldown
1886
- if classified_error.status_code == 429 or classified_error.error_type in ["rate_limit", "quota_exceeded"]:
 
 
 
 
1887
  cooldown_duration = classified_error.retry_after or 60
1888
  await self.cooldown_manager.start_cooldown(
1889
  provider, cooldown_duration
@@ -1904,7 +2006,9 @@ class RotatingClient:
1904
  await self.usage_manager.record_failure(
1905
  current_cred, model, classified_error
1906
  )
1907
- lib_logger.info(f"Rotating to next key after {classified_error.error_type} error.")
 
 
1908
  break
1909
 
1910
  finally:
@@ -1913,26 +2017,28 @@ class RotatingClient:
1913
 
1914
  # Build detailed error response using error accumulator
1915
  error_accumulator.timeout_occurred = time.time() >= deadline
1916
- error_accumulator.model = model
1917
- error_accumulator.provider = provider
1918
-
1919
  if error_accumulator.has_errors():
1920
  # Log concise summary for server logs
1921
  lib_logger.error(error_accumulator.build_log_message())
1922
-
1923
  # Build structured error response for client
1924
  error_response = error_accumulator.build_client_error_response()
1925
  error_data = error_response
1926
  else:
1927
  # Fallback if no errors were recorded (shouldn't happen)
1928
- final_error_message = "Request failed: No available API keys after rotation or timeout."
 
 
1929
  if last_exception:
1930
- final_error_message = f"Request failed. Last error: {str(last_exception)}"
 
 
1931
  error_data = {
1932
  "error": {"message": final_error_message, "type": "proxy_error"}
1933
  }
1934
  lib_logger.error(final_error_message)
1935
-
1936
  yield f"data: {json.dumps(error_data)}\n\n"
1937
  yield "data: [DONE]\n\n"
1938
 
@@ -1980,11 +2086,13 @@ class RotatingClient:
1980
  # Handle iflow provider: remove stream_options to avoid HTTP 406
1981
  model = kwargs.get("model", "")
1982
  provider = model.split("/")[0] if "/" in model else ""
1983
-
1984
  if provider == "iflow" and "stream_options" in kwargs:
1985
- lib_logger.debug("Removing stream_options for iflow provider to avoid HTTP 406")
 
 
1986
  kwargs.pop("stream_options", None)
1987
-
1988
  if kwargs.get("stream"):
1989
  # Only add stream_options for providers that support it (excluding iflow)
1990
  if provider != "iflow":
@@ -1992,7 +2100,7 @@ class RotatingClient:
1992
  kwargs["stream_options"] = {}
1993
  if "include_usage" not in kwargs["stream_options"]:
1994
  kwargs["stream_options"]["include_usage"] = True
1995
-
1996
  return self._streaming_acompletion_with_retry(
1997
  request=request, pre_request_callback=pre_request_callback, **kwargs
1998
  )
 
71
  ):
72
  """
73
  Initialize the RotatingClient with intelligent credential rotation.
74
+
75
  Args:
76
  api_keys: Dictionary mapping provider names to lists of API keys
77
  oauth_credentials: Dictionary mapping provider names to OAuth credential paths
 
140
  self.global_timeout = global_timeout
141
  self.abort_on_callback_error = abort_on_callback_error
142
  self.usage_manager = UsageManager(
143
+ file_path=usage_file_path, rotation_tolerance=rotation_tolerance
 
144
  )
145
  self._model_list_cache = {}
146
  self._provider_plugins = PROVIDER_PLUGINS
 
159
  # Validate all values are >= 1
160
  for provider, max_val in self.max_concurrent_requests_per_key.items():
161
  if max_val < 1:
162
+ lib_logger.warning(
163
+ f"Invalid max_concurrent for '{provider}': {max_val}. Setting to 1."
164
+ )
165
  self.max_concurrent_requests_per_key[provider] = 1
166
 
167
  def _is_model_ignored(self, provider: str, model_id: str) -> bool:
 
369
 
370
  return kwargs
371
 
372
+ def _apply_default_safety_settings(
373
+ self, litellm_kwargs: Dict[str, Any], provider: str
374
+ ):
375
  """
376
  Ensure default Gemini safety settings are present when calling the Gemini provider.
377
  This will not override any explicit settings provided by the request. It accepts
 
400
  ]
401
 
402
  # If generic form is present, ensure missing generic keys are filled in
403
+ if "safety_settings" in litellm_kwargs and isinstance(
404
+ litellm_kwargs["safety_settings"], dict
405
+ ):
406
  for k, v in default_generic.items():
407
  if k not in litellm_kwargs["safety_settings"]:
408
  litellm_kwargs["safety_settings"][k] = v
409
  return
410
 
411
  # If Gemini form is present, ensure missing gemini categories are appended
412
+ if "safetySettings" in litellm_kwargs and isinstance(
413
+ litellm_kwargs["safetySettings"], list
414
+ ):
415
+ present = {
416
+ item.get("category")
417
+ for item in litellm_kwargs["safetySettings"]
418
+ if isinstance(item, dict)
419
+ }
420
  for d in default_gemini:
421
  if d["category"] not in present:
422
  litellm_kwargs["safetySettings"].append(d)
423
  return
424
 
425
  # Neither present: set generic defaults so provider conversion will translate them
426
+ if (
427
+ "safety_settings" not in litellm_kwargs
428
+ and "safetySettings" not in litellm_kwargs
429
+ ):
430
  litellm_kwargs["safety_settings"] = default_generic.copy()
431
 
432
  def get_oauth_credentials(self) -> Dict[str, List[str]]:
 
444
  """
445
  Lazily initializes and returns a provider instance.
446
  Only initializes providers that have configured credentials.
447
+
448
  Args:
449
  provider_name: The name of the provider to get an instance for.
450
+
451
  Returns:
452
  Provider instance if credentials exist, None otherwise.
453
  """
 
457
  f"Skipping provider '{provider_name}' initialization: no credentials configured"
458
  )
459
  return None
460
+
461
  if provider_name not in self._provider_instances:
462
  if provider_name in self._provider_plugins:
463
  self._provider_instances[provider_name] = self._provider_plugins[
 
479
  def _resolve_model_id(self, model: str, provider: str) -> str:
480
  """
481
  Resolves the actual model ID to send to the provider.
482
+
483
  For custom models with name/ID mappings, returns the ID.
484
  Otherwise, returns the model name unchanged.
485
+
486
  Args:
487
  model: Full model string with provider (e.g., "iflow/DS-v3.2")
488
  provider: Provider name (e.g., "iflow")
489
+
490
  Returns:
491
  Full model string with ID (e.g., "iflow/deepseek-v3.2")
492
  """
493
  # Extract model name from "provider/model_name" format
494
+ model_name = model.split("/")[-1] if "/" in model else model
495
+
496
  # Try to get provider instance to check for model definitions
497
  provider_plugin = self._get_provider_instance(provider)
498
+
499
  # Check if provider has model definitions
500
+ if provider_plugin and hasattr(provider_plugin, "model_definitions"):
501
+ model_id = provider_plugin.model_definitions.get_model_id(
502
+ provider, model_name
503
+ )
504
  if model_id and model_id != model_name:
505
  # Return with provider prefix
506
  return f"{provider}/{model_id}"
507
+
508
  # Fallback: use client's own model definitions
509
  model_id = self.model_definitions.get_model_id(provider, model_name)
510
  if model_id and model_id != model_name:
511
  return f"{provider}/{model_id}"
512
+
513
  # No conversion needed, return original
514
  return model
515
 
 
516
  async def _safe_streaming_wrapper(
517
  self, stream: Any, key: str, model: str, request: Optional[Any] = None
518
  ) -> AsyncGenerator[Any, None]:
519
  """
520
  A hybrid wrapper for streaming that buffers fragmented JSON, handles client disconnections gracefully,
521
  and distinguishes between content and streamed errors.
522
+
523
  FINISH_REASON HANDLING:
524
  Providers just translate chunks - this wrapper handles ALL finish_reason logic:
525
  1. Strip finish_reason from intermediate chunks (litellm defaults to "stop")
 
556
  chunk_dict = chunk.model_dump()
557
  else:
558
  chunk_dict = chunk
559
+
560
  # === FINISH_REASON LOGIC ===
561
  # Providers send raw chunks without finish_reason logic.
562
  # This wrapper determines finish_reason based on accumulated state.
 
564
  choice = chunk_dict["choices"][0]
565
  delta = choice.get("delta", {})
566
  usage = chunk_dict.get("usage", {})
567
+
568
  # Track tool_calls across ALL chunks - if we ever see one, finish_reason must be tool_calls
569
  if delta.get("tool_calls"):
570
  has_tool_calls = True
571
  accumulated_finish_reason = "tool_calls"
572
+
573
  # Detect final chunk: has usage with completion_tokens > 0
574
  has_completion_tokens = (
575
+ usage
576
+ and isinstance(usage, dict)
577
+ and usage.get("completion_tokens", 0) > 0
578
  )
579
+
580
  if has_completion_tokens:
581
  # FINAL CHUNK: Determine correct finish_reason
582
  if has_tool_calls:
 
592
  # INTERMEDIATE CHUNK: Never emit finish_reason
593
  # (litellm.ModelResponse defaults to "stop" which is wrong)
594
  choice["finish_reason"] = None
595
+
596
  yield f"data: {json.dumps(chunk_dict)}\n\n"
597
 
598
  if hasattr(chunk, "usage") and chunk.usage:
 
741
  # multiple keys have the same usage stats.
742
  credentials_for_provider = list(self.all_credentials[provider])
743
  random.shuffle(credentials_for_provider)
744
+
745
  # Filter out credentials that are unavailable (queued for re-auth)
746
  provider_plugin = self._get_provider_instance(provider)
747
+ if provider_plugin and hasattr(provider_plugin, "is_credential_available"):
748
  available_creds = [
749
+ cred
750
+ for cred in credentials_for_provider
751
  if provider_plugin.is_credential_available(cred)
752
  ]
753
  if available_creds:
 
760
  kwargs = self._convert_model_params(**kwargs)
761
 
762
  # The main rotation loop. It continues as long as there are untried credentials and the global deadline has not been exceeded.
763
+
764
  # Resolve model ID early, before any credential operations
765
  # This ensures consistent model ID usage for acquisition, release, and tracking
766
  resolved_model = self._resolve_model_id(model, provider)
 
768
  lib_logger.info(f"Resolved model '{model}' to '{resolved_model}'")
769
  model = resolved_model
770
  kwargs["model"] = model # Ensure kwargs has the resolved model for litellm
771
+
772
  # [NEW] Filter by model tier requirement and build priority map
773
  credential_priorities = None
774
+ if provider_plugin and hasattr(provider_plugin, "get_model_tier_requirement"):
775
  required_tier = provider_plugin.get_model_tier_requirement(model)
776
  if required_tier is not None:
777
  # Filter OUT only credentials we KNOW are too low priority
 
779
  incompatible_creds = []
780
  compatible_creds = []
781
  unknown_creds = []
782
+
783
  for cred in credentials_for_provider:
784
+ if hasattr(provider_plugin, "get_credential_priority"):
785
  priority = provider_plugin.get_credential_priority(cred)
786
  if priority is None:
787
  # Unknown priority - keep it, will be discovered on first use
 
795
  else:
796
  # Provider doesn't support priorities - keep all
797
  unknown_creds.append(cred)
798
+
799
  # If we have any known-compatible or unknown credentials, use them
800
  tier_compatible_creds = compatible_creds + unknown_creds
801
  if tier_compatible_creds:
 
822
  f"but all {len(incompatible_creds)} known credentials have priority > {required_tier}. "
823
  f"Request will likely fail."
824
  )
825
+
826
  # Build priority map for usage_manager
827
+ if provider_plugin and hasattr(provider_plugin, "get_credential_priority"):
828
  credential_priorities = {}
829
  for cred in credentials_for_provider:
830
  priority = provider_plugin.get_credential_priority(cred)
831
  if priority is not None:
832
  credential_priorities[cred] = priority
833
+
834
  if credential_priorities:
835
  lib_logger.debug(
836
+ f"Credential priorities for {provider}: {', '.join(f'P{p}={len([c for c in credentials_for_provider if credential_priorities.get(c) == p])}' for p in sorted(set(credential_priorities.values())))}"
837
  )
838
 
839
  # Initialize error accumulator for tracking errors across credential rotation
 
877
  )
878
  max_concurrent = self.max_concurrent_requests_per_key.get(provider, 1)
879
  current_cred = await self.usage_manager.acquire_key(
880
+ available_keys=creds_to_try,
881
+ model=model,
882
+ deadline=deadline,
883
  max_concurrent=max_concurrent,
884
+ credential_priorities=credential_priorities,
885
  )
886
  key_acquired = True
887
  tried_creds.add(current_cred)
 
964
  if provider_instance:
965
  # Ensure default Gemini safety settings are present (without overriding request)
966
  try:
967
+ self._apply_default_safety_settings(
968
+ litellm_kwargs, provider
969
+ )
970
  except Exception:
971
  # If anything goes wrong here, avoid breaking the request flow.
972
+ lib_logger.debug(
973
+ "Could not apply default safety settings; continuing."
974
+ )
975
 
976
  if "safety_settings" in litellm_kwargs:
977
  converted_settings = (
 
1054
 
1055
  # Extract a clean error message for the user-facing log
1056
  error_message = str(e).split("\n")[0]
1057
+
1058
  # Record in accumulator for client reporting
1059
+ error_accumulator.record_error(
1060
+ current_cred, classified_error, error_message
1061
+ )
1062
 
1063
  lib_logger.info(
1064
  f"Key {mask_credential(current_cred)} hit rate limit for {model}. Rotating key."
 
1092
  )
1093
  classified_error = classify_error(e)
1094
  error_message = str(e).split("\n")[0]
1095
+
1096
  # Provider-level error: don't increment consecutive failures
1097
  await self.usage_manager.record_failure(
1098
+ current_cred,
1099
+ model,
1100
+ classified_error,
1101
+ increment_consecutive_failures=False,
1102
  )
1103
 
1104
  if attempt >= self.max_retries - 1:
1105
  # Record in accumulator only on final failure for this key
1106
+ error_accumulator.record_error(
1107
+ current_cred, classified_error, error_message
1108
+ )
1109
  lib_logger.warning(
1110
  f"Key {mask_credential(current_cred)} failed after max retries due to server error. Rotating."
1111
  )
 
1113
 
1114
  # For temporary errors, wait before retrying with the same key.
1115
  wait_time = classified_error.retry_after or (
1116
+ 2**attempt
1117
  ) + random.uniform(0, 1)
1118
  remaining_budget = deadline - time.time()
1119
 
1120
  # If the required wait time exceeds the budget, don't wait; rotate to the next key immediately.
1121
  if wait_time > remaining_budget:
1122
+ error_accumulator.record_error(
1123
+ current_cred, classified_error, error_message
1124
+ )
1125
  lib_logger.warning(
1126
  f"Retry wait ({wait_time:.2f}s) exceeds budget ({remaining_budget:.2f}s). Rotating key."
1127
  )
 
1145
  if request
1146
  else {},
1147
  )
1148
+
1149
  classified_error = classify_error(e)
1150
  error_message = str(e).split("\n")[0]
1151
+
 
 
 
1152
  lib_logger.warning(
1153
  f"Key {mask_credential(current_cred)} HTTP {e.response.status_code} ({classified_error.error_type})."
1154
  )
1155
+
1156
  # Check if this error should trigger rotation
1157
  if not should_rotate_on_error(classified_error):
1158
  lib_logger.error(
1159
  f"Non-recoverable error ({classified_error.error_type}). Failing request."
1160
  )
1161
  raise last_exception
1162
+
1163
+ # Record in accumulator after confirming it's a rotatable error
1164
+ error_accumulator.record_error(
1165
+ current_cred, classified_error, error_message
1166
+ )
1167
+
1168
  # Handle rate limits with cooldown
1169
+ if classified_error.error_type in [
1170
+ "rate_limit",
1171
+ "quota_exceeded",
1172
+ ]:
1173
  cooldown_duration = classified_error.retry_after or 60
1174
  await self.cooldown_manager.start_cooldown(
1175
  provider, cooldown_duration
1176
  )
1177
+
1178
  # Check if we should retry same key (server errors with retries left)
1179
+ if (
1180
+ should_retry_same_key(classified_error)
1181
+ and attempt < self.max_retries - 1
1182
+ ):
1183
+ wait_time = classified_error.retry_after or (
1184
+ 2**attempt
1185
+ ) + random.uniform(0, 1)
1186
  remaining_budget = deadline - time.time()
1187
  if wait_time <= remaining_budget:
1188
  lib_logger.warning(
 
1190
  )
1191
  await asyncio.sleep(wait_time)
1192
  continue
1193
+
1194
  # Record failure and rotate to next key
1195
  await self.usage_manager.record_failure(
1196
  current_cred, model, classified_error
1197
  )
1198
+ lib_logger.info(
1199
+ f"Rotating to next key after {classified_error.error_type} error."
1200
+ )
1201
  break
1202
 
1203
  except Exception as e:
 
1220
 
1221
  classified_error = classify_error(e)
1222
  error_message = str(e).split("\n")[0]
1223
+
 
 
 
1224
  lib_logger.warning(
1225
  f"Key {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code})."
1226
  )
1227
+
1228
  # Handle rate limits with cooldown
1229
+ if (
1230
+ classified_error.status_code == 429
1231
+ or classified_error.error_type
1232
+ in ["rate_limit", "quota_exceeded"]
1233
+ ):
1234
  cooldown_duration = classified_error.retry_after or 60
1235
  await self.cooldown_manager.start_cooldown(
1236
  provider, cooldown_duration
 
1243
  )
1244
  raise last_exception
1245
 
1246
+ # Record in accumulator after confirming it's a rotatable error
1247
+ error_accumulator.record_error(
1248
+ current_cred, classified_error, error_message
1249
+ )
1250
+
1251
  await self.usage_manager.record_failure(
1252
  current_cred, model, classified_error
1253
  )
 
1259
  # Check if we exhausted all credentials or timed out
1260
  if time.time() >= deadline:
1261
  error_accumulator.timeout_occurred = True
1262
+
1263
  if error_accumulator.has_errors():
1264
  # Log concise summary for server logs
1265
  lib_logger.error(error_accumulator.build_log_message())
1266
+
1267
  # Return the structured error response for the client
1268
  return error_accumulator.build_client_error_response()
1269
 
1270
  # Return None to indicate failure without error details (shouldn't normally happen)
1271
+ lib_logger.warning(
1272
+ "Unexpected state: request failed with no recorded errors. "
1273
+ "This may indicate a logic error in error tracking."
1274
+ )
1275
  return None
1276
 
1277
  async def _streaming_acompletion_with_retry(
 
1287
  # Create a mutable copy of the keys and shuffle it.
1288
  credentials_for_provider = list(self.all_credentials[provider])
1289
  random.shuffle(credentials_for_provider)
1290
+
1291
  # Filter out credentials that are unavailable (queued for re-auth)
1292
  provider_plugin = self._get_provider_instance(provider)
1293
+ if provider_plugin and hasattr(provider_plugin, "is_credential_available"):
1294
  available_creds = [
1295
+ cred
1296
+ for cred in credentials_for_provider
1297
  if provider_plugin.is_credential_available(cred)
1298
  ]
1299
  if available_creds:
 
1315
  lib_logger.info(f"Resolved model '{model}' to '{resolved_model}'")
1316
  model = resolved_model
1317
  kwargs["model"] = model # Ensure kwargs has the resolved model for litellm
1318
+
1319
  # [NEW] Filter by model tier requirement and build priority map
1320
  credential_priorities = None
1321
+ if provider_plugin and hasattr(provider_plugin, "get_model_tier_requirement"):
1322
  required_tier = provider_plugin.get_model_tier_requirement(model)
1323
  if required_tier is not None:
1324
  # Filter OUT only credentials we KNOW are too low priority
 
1326
  incompatible_creds = []
1327
  compatible_creds = []
1328
  unknown_creds = []
1329
+
1330
  for cred in credentials_for_provider:
1331
+ if hasattr(provider_plugin, "get_credential_priority"):
1332
  priority = provider_plugin.get_credential_priority(cred)
1333
  if priority is None:
1334
  # Unknown priority - keep it, will be discovered on first use
 
1342
  else:
1343
  # Provider doesn't support priorities - keep all
1344
  unknown_creds.append(cred)
1345
+
1346
  # If we have any known-compatible or unknown credentials, use them
1347
  tier_compatible_creds = compatible_creds + unknown_creds
1348
  if tier_compatible_creds:
 
1369
  f"but all {len(incompatible_creds)} known credentials have priority > {required_tier}. "
1370
  f"Request will likely fail."
1371
  )
1372
+
1373
  # Build priority map for usage_manager
1374
+ if provider_plugin and hasattr(provider_plugin, "get_credential_priority"):
1375
  credential_priorities = {}
1376
  for cred in credentials_for_provider:
1377
  priority = provider_plugin.get_credential_priority(cred)
1378
  if priority is not None:
1379
  credential_priorities[cred] = priority
1380
+
1381
  if credential_priorities:
1382
  lib_logger.debug(
1383
+ f"Credential priorities for {provider}: {', '.join(f'P{p}={len([c for c in credentials_for_provider if credential_priorities.get(c) == p])}' for p in sorted(set(credential_priorities.values())))}"
1384
  )
1385
 
1386
  # Initialize error accumulator for tracking errors across credential rotation
 
1423
  lib_logger.info(
1424
  f"Acquiring credential for model {model}. Tried credentials: {len(tried_creds)}/{len(credentials_for_provider)}"
1425
  )
1426
+ max_concurrent = self.max_concurrent_requests_per_key.get(
1427
+ provider, 1
1428
+ )
1429
  current_cred = await self.usage_manager.acquire_key(
1430
+ available_keys=creds_to_try,
1431
+ model=model,
1432
+ deadline=deadline,
1433
  max_concurrent=max_concurrent,
1434
+ credential_priorities=credential_priorities,
1435
  )
1436
  key_acquired = True
1437
  tried_creds.add(current_cred)
 
1540
  original_exc = getattr(e, "data", e)
1541
  classified_error = classify_error(original_exc)
1542
  error_message = str(original_exc).split("\n")[0]
1543
+
1544
  log_failure(
1545
  api_key=current_cred,
1546
  model=model,
 
1550
  if request
1551
  else {},
1552
  )
1553
+
1554
  # Record in accumulator for client reporting
1555
+ error_accumulator.record_error(
1556
+ current_cred, classified_error, error_message
1557
+ )
1558
+
1559
  # Check if this error should trigger rotation
1560
  if not should_rotate_on_error(classified_error):
1561
  lib_logger.error(
1562
  f"Non-recoverable error ({classified_error.error_type}) during custom stream. Failing."
1563
  )
1564
  raise last_exception
1565
+
1566
  # Handle rate limits with cooldown
1567
+ if classified_error.error_type in [
1568
+ "rate_limit",
1569
+ "quota_exceeded",
1570
+ ]:
1571
+ cooldown_duration = (
1572
+ classified_error.retry_after or 60
1573
+ )
1574
  await self.cooldown_manager.start_cooldown(
1575
  provider, cooldown_duration
1576
  )
1577
+
1578
  await self.usage_manager.record_failure(
1579
  current_cred, model, classified_error
1580
  )
 
1600
  )
1601
  classified_error = classify_error(e)
1602
  error_message = str(e).split("\n")[0]
1603
+
1604
  # Provider-level error: don't increment consecutive failures
1605
  await self.usage_manager.record_failure(
1606
+ current_cred,
1607
+ model,
1608
+ classified_error,
1609
+ increment_consecutive_failures=False,
1610
  )
1611
 
1612
  if attempt >= self.max_retries - 1:
1613
+ error_accumulator.record_error(
1614
+ current_cred, classified_error, error_message
1615
+ )
1616
  lib_logger.warning(
1617
  f"Cred {mask_credential(current_cred)} failed after max retries. Rotating."
1618
  )
1619
  break
1620
 
1621
  wait_time = classified_error.retry_after or (
1622
+ 2**attempt
1623
  ) + random.uniform(0, 1)
1624
  remaining_budget = deadline - time.time()
1625
  if wait_time > remaining_budget:
1626
+ error_accumulator.record_error(
1627
+ current_cred, classified_error, error_message
1628
+ )
1629
  lib_logger.warning(
1630
  f"Retry wait ({wait_time:.2f}s) exceeds budget. Rotating."
1631
  )
 
1650
  )
1651
  classified_error = classify_error(e)
1652
  error_message = str(e).split("\n")[0]
1653
+
1654
  # Record in accumulator
1655
+ error_accumulator.record_error(
1656
+ current_cred, classified_error, error_message
1657
+ )
1658
+
1659
  lib_logger.warning(
1660
  f"Cred {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code})."
1661
  )
1662
+
1663
  # Check if this error should trigger rotation
1664
  if not should_rotate_on_error(classified_error):
1665
  lib_logger.error(
1666
  f"Non-recoverable error ({classified_error.error_type}). Failing."
1667
  )
1668
  raise last_exception
1669
+
1670
  await self.usage_manager.record_failure(
1671
  current_cred, model, classified_error
1672
  )
 
1688
  if provider_instance:
1689
  # Ensure default Gemini safety settings are present (without overriding request)
1690
  try:
1691
+ self._apply_default_safety_settings(
1692
+ litellm_kwargs, provider
1693
+ )
1694
  except Exception:
1695
+ lib_logger.debug(
1696
+ "Could not apply default safety settings for streaming path; continuing."
1697
+ )
1698
 
1699
  if "safety_settings" in litellm_kwargs:
1700
  converted_settings = (
 
1775
  yield chunk
1776
  return
1777
 
1778
+ except (
1779
+ StreamedAPIError,
1780
+ litellm.RateLimitError,
1781
+ httpx.HTTPStatusError,
1782
+ ) as e:
1783
  last_exception = e
1784
 
1785
  # This is the final, robust handler for streamed errors.
 
1788
  # The actual exception might be wrapped in our StreamedAPIError.
1789
  original_exc = getattr(e, "data", e)
1790
  classified_error = classify_error(original_exc)
1791
+
1792
  # Check if this error should trigger rotation
1793
  if not should_rotate_on_error(classified_error):
1794
  lib_logger.error(
 
1825
  error_message_text = error_details.get(
1826
  "message", str(original_exc).split("\n")[0]
1827
  )
1828
+
1829
  # Record in accumulator for client reporting
1830
+ error_accumulator.record_error(
1831
+ current_cred, classified_error, error_message_text
1832
+ )
1833
 
1834
  if (
1835
  "quota" in error_message_text.lower()
1836
  or "resource_exhausted" in error_status.lower()
1837
  ):
1838
  consecutive_quota_failures += 1
1839
+
1840
  quota_value = "N/A"
1841
  quota_id = "N/A"
1842
  if "details" in error_details and isinstance(
 
1846
  if isinstance(detail.get("violations"), list):
1847
  for violation in detail["violations"]:
1848
  if "quotaValue" in violation:
1849
+ quota_value = violation[
1850
+ "quotaValue"
1851
+ ]
1852
  if "quotaId" in violation:
1853
  quota_id = violation["quotaId"]
1854
+ if (
1855
+ quota_value != "N/A"
1856
+ and quota_id != "N/A"
1857
+ ):
1858
  break
1859
 
1860
  await self.usage_manager.record_failure(
 
1885
  f"Cred {mask_credential(current_cred)} {classified_error.error_type}. Rotating."
1886
  )
1887
 
1888
+ if classified_error.error_type in [
1889
+ "rate_limit",
1890
+ "quota_exceeded",
1891
+ ]:
1892
+ cooldown_duration = (
1893
+ classified_error.retry_after or 60
1894
+ )
1895
  await self.cooldown_manager.start_cooldown(
1896
  provider, cooldown_duration
1897
  )
 
1919
  )
1920
  classified_error = classify_error(e)
1921
  error_message_text = str(e).split("\n")[0]
1922
+
1923
+ # Record error in accumulator (server errors are transient, not abnormal)
1924
+ error_accumulator.record_error(
1925
+ current_cred, classified_error, error_message_text
1926
+ )
1927
+
1928
  # Provider-level error: don't increment consecutive failures
1929
  await self.usage_manager.record_failure(
1930
+ current_cred,
1931
+ model,
1932
+ classified_error,
1933
+ increment_consecutive_failures=False,
1934
  )
1935
 
1936
  if attempt >= self.max_retries - 1:
 
1941
  break
1942
 
1943
  wait_time = classified_error.retry_after or (
1944
+ 2**attempt
1945
  ) + random.uniform(0, 1)
1946
  remaining_budget = deadline - time.time()
1947
  if wait_time > remaining_budget:
 
1970
  )
1971
  classified_error = classify_error(e)
1972
  error_message_text = str(e).split("\n")[0]
1973
+
1974
  # Record error in accumulator
1975
+ error_accumulator.record_error(
1976
+ current_cred, classified_error, error_message_text
1977
+ )
1978
 
1979
  lib_logger.warning(
1980
  f"Credential ...{current_cred[-6:]} failed with {classified_error.error_type} (Status: {classified_error.status_code}). Error: {error_message_text}."
1981
  )
1982
 
1983
  # Handle rate limits with cooldown
1984
+ if (
1985
+ classified_error.status_code == 429
1986
+ or classified_error.error_type
1987
+ in ["rate_limit", "quota_exceeded"]
1988
+ ):
1989
  cooldown_duration = classified_error.retry_after or 60
1990
  await self.cooldown_manager.start_cooldown(
1991
  provider, cooldown_duration
 
2006
  await self.usage_manager.record_failure(
2007
  current_cred, model, classified_error
2008
  )
2009
+ lib_logger.info(
2010
+ f"Rotating to next key after {classified_error.error_type} error."
2011
+ )
2012
  break
2013
 
2014
  finally:
 
2017
 
2018
  # Build detailed error response using error accumulator
2019
  error_accumulator.timeout_occurred = time.time() >= deadline
2020
+
 
 
2021
  if error_accumulator.has_errors():
2022
  # Log concise summary for server logs
2023
  lib_logger.error(error_accumulator.build_log_message())
2024
+
2025
  # Build structured error response for client
2026
  error_response = error_accumulator.build_client_error_response()
2027
  error_data = error_response
2028
  else:
2029
  # Fallback if no errors were recorded (shouldn't happen)
2030
+ final_error_message = (
2031
+ "Request failed: No available API keys after rotation or timeout."
2032
+ )
2033
  if last_exception:
2034
+ final_error_message = (
2035
+ f"Request failed. Last error: {str(last_exception)}"
2036
+ )
2037
  error_data = {
2038
  "error": {"message": final_error_message, "type": "proxy_error"}
2039
  }
2040
  lib_logger.error(final_error_message)
2041
+
2042
  yield f"data: {json.dumps(error_data)}\n\n"
2043
  yield "data: [DONE]\n\n"
2044
 
 
2086
  # Handle iflow provider: remove stream_options to avoid HTTP 406
2087
  model = kwargs.get("model", "")
2088
  provider = model.split("/")[0] if "/" in model else ""
2089
+
2090
  if provider == "iflow" and "stream_options" in kwargs:
2091
+ lib_logger.debug(
2092
+ "Removing stream_options for iflow provider to avoid HTTP 406"
2093
+ )
2094
  kwargs.pop("stream_options", None)
2095
+
2096
  if kwargs.get("stream"):
2097
  # Only add stream_options for providers that support it (excluding iflow)
2098
  if provider != "iflow":
 
2100
  kwargs["stream_options"] = {}
2101
  if "include_usage" not in kwargs["stream_options"]:
2102
  kwargs["stream_options"]["include_usage"] = True
2103
+
2104
  return self._streaming_acompletion_with_retry(
2105
  request=request, pre_request_callback=pre_request_callback, **kwargs
2106
  )
src/rotator_library/error_handler.py CHANGED
@@ -1,5 +1,6 @@
1
  import re
2
  import json
 
3
  from typing import Optional, Dict, Any
4
  import httpx
5
 
@@ -20,20 +21,20 @@ from litellm.exceptions import (
20
  def extract_retry_after_from_body(error_body: Optional[str]) -> Optional[int]:
21
  """
22
  Extract the retry-after time from an API error response body.
23
-
24
  Handles various error formats including:
25
  - Gemini CLI: "Your quota will reset after 39s."
26
  - Generic: "quota will reset after 120s", "retry after 60s"
27
-
28
  Args:
29
  error_body: The raw error response body
30
-
31
  Returns:
32
  The retry time in seconds, or None if not found
33
  """
34
  if not error_body:
35
  return None
36
-
37
  # Pattern to match various "reset after Xs" or "retry after Xs" formats
38
  patterns = [
39
  r"quota will reset after\s*(\d+)s",
@@ -41,7 +42,7 @@ def extract_retry_after_from_body(error_body: Optional[str]) -> Optional[int]:
41
  r"retry after\s*(\d+)s",
42
  r"try again in\s*(\d+)\s*seconds?",
43
  ]
44
-
45
  for pattern in patterns:
46
  match = re.search(pattern, error_body, re.IGNORECASE)
47
  if match:
@@ -49,7 +50,7 @@ def extract_retry_after_from_body(error_body: Optional[str]) -> Optional[int]:
49
  return int(match.group(1))
50
  except (ValueError, IndexError):
51
  continue
52
-
53
  return None
54
 
55
 
@@ -70,29 +71,33 @@ class PreRequestCallbackError(Exception):
70
  # =============================================================================
71
 
72
  # Abnormal errors that require attention and should always be reported to client
73
- ABNORMAL_ERROR_TYPES = frozenset({
74
- "forbidden", # 403 - credential access issue
75
- "authentication", # 401 - credential invalid/revoked
76
- "pre_request_callback_error", # Internal proxy error
77
- })
 
 
78
 
79
  # Normal/expected errors during operation - only report if ALL credentials fail
80
- NORMAL_ERROR_TYPES = frozenset({
81
- "rate_limit", # 429 - expected during high load
82
- "quota_exceeded", # Expected when quota runs out
83
- "server_error", # 5xx - transient provider issues
84
- "api_connection", # Network issues - transient
85
- })
 
 
86
 
87
 
88
  def is_abnormal_error(classified_error: "ClassifiedError") -> bool:
89
  """
90
  Check if an error is abnormal and should be reported to the client.
91
-
92
  Abnormal errors indicate credential issues that need attention:
93
  - 403 Forbidden: Credential doesn't have access
94
  - 401 Unauthorized: Credential is invalid/revoked
95
-
96
  Normal errors are expected during operation:
97
  - 429 Rate limit: Expected during high load
98
  - 5xx Server errors: Transient provider issues
@@ -103,11 +108,10 @@ def is_abnormal_error(classified_error: "ClassifiedError") -> bool:
103
  def mask_credential(credential: str) -> str:
104
  """
105
  Mask a credential for safe display in logs and error messages.
106
-
107
  - For API keys: shows last 6 characters (e.g., "...xyz123")
108
  - For OAuth file paths: shows just the filename (e.g., "antigravity_oauth_1.json")
109
  """
110
- import os
111
  if os.path.isfile(credential):
112
  return os.path.basename(credential)
113
  elif len(credential) > 6:
@@ -119,77 +123,79 @@ def mask_credential(credential: str) -> str:
119
  class RequestErrorAccumulator:
120
  """
121
  Tracks errors encountered during a request's credential rotation cycle.
122
-
123
  Used to build informative error messages for clients when all credentials
124
  are exhausted. Distinguishes between abnormal errors (that need attention)
125
  and normal errors (expected during operation).
126
  """
127
-
128
  def __init__(self):
129
  self.abnormal_errors: list = [] # 403, 401 - always report details
130
- self.normal_errors: list = [] # 429, 5xx - summarize only
131
- self.total_credentials_tried: int = 0
132
  self.timeout_occurred: bool = False
133
  self.model: str = ""
134
  self.provider: str = ""
135
-
136
  def record_error(
137
- self,
138
- credential: str,
139
- classified_error: "ClassifiedError",
140
- error_message: str
141
  ):
142
  """Record an error for a credential."""
143
- self.total_credentials_tried += 1
144
  masked_cred = mask_credential(credential)
145
-
146
  error_record = {
147
  "credential": masked_cred,
148
  "error_type": classified_error.error_type,
149
  "status_code": classified_error.status_code,
150
- "message": self._truncate_message(error_message, 150)
151
  }
152
-
153
  if is_abnormal_error(classified_error):
154
  self.abnormal_errors.append(error_record)
155
  else:
156
  self.normal_errors.append(error_record)
157
-
 
 
 
 
 
158
  def _truncate_message(self, message: str, max_length: int = 150) -> str:
159
  """Truncate error message for readability."""
160
  # Take first line and truncate
161
- first_line = message.split('\n')[0]
162
  if len(first_line) > max_length:
163
  return first_line[:max_length] + "..."
164
  return first_line
165
-
166
  def has_errors(self) -> bool:
167
  """Check if any errors were recorded."""
168
  return bool(self.abnormal_errors or self.normal_errors)
169
-
170
  def has_abnormal_errors(self) -> bool:
171
  """Check if any abnormal errors were recorded."""
172
  return bool(self.abnormal_errors)
173
-
174
  def get_normal_error_summary(self) -> str:
175
  """Get a summary of normal errors (not individual details)."""
176
  if not self.normal_errors:
177
  return ""
178
-
179
  # Count by type
180
  counts = {}
181
  for err in self.normal_errors:
182
  err_type = err["error_type"]
183
  counts[err_type] = counts.get(err_type, 0) + 1
184
-
185
  # Build summary like "3 rate_limit, 1 server_error"
186
  parts = [f"{count} {err_type}" for err_type, count in counts.items()]
187
  return ", ".join(parts)
188
-
189
  def build_client_error_response(self) -> dict:
190
  """
191
  Build a structured error response for the client.
192
-
193
  Returns a dict suitable for JSON serialization in the error response.
194
  """
195
  # Determine the primary failure reason
@@ -199,24 +205,34 @@ class RequestErrorAccumulator:
199
  else:
200
  error_type = "proxy_all_credentials_exhausted"
201
  base_message = f"All {self.total_credentials_tried} credential(s) exhausted for {self.provider}"
202
-
203
  # Build human-readable message
204
  message_parts = [base_message]
205
-
206
  if self.abnormal_errors:
207
  message_parts.append("\n\nCredential issues (require attention):")
208
  for err in self.abnormal_errors:
209
- status = f"HTTP {err['status_code']}" if err['status_code'] else err['error_type']
210
- message_parts.append(f"\n • {err['credential']}: {status} - {err['message']}")
211
-
 
 
 
 
 
 
212
  normal_summary = self.get_normal_error_summary()
213
  if normal_summary:
214
  if self.abnormal_errors:
215
- message_parts.append(f"\n\nAdditionally: {normal_summary} (expected during normal operation)")
 
 
216
  else:
217
  message_parts.append(f"\n\nAll failures were: {normal_summary}")
218
- message_parts.append("\nThis is normal during high load - retry later or add more credentials.")
219
-
 
 
220
  response = {
221
  "error": {
222
  "message": "".join(message_parts),
@@ -226,44 +242,48 @@ class RequestErrorAccumulator:
226
  "provider": self.provider,
227
  "credentials_tried": self.total_credentials_tried,
228
  "timeout": self.timeout_occurred,
229
- }
230
  }
231
  }
232
-
233
  # Only include abnormal errors in details (they need attention)
234
  if self.abnormal_errors:
235
  response["error"]["details"]["abnormal_errors"] = self.abnormal_errors
236
-
237
  # Include summary of normal errors
238
  if normal_summary:
239
  response["error"]["details"]["normal_error_summary"] = normal_summary
240
-
241
  return response
242
-
243
  def build_log_message(self) -> str:
244
  """
245
  Build a concise log message for server-side logging.
246
-
247
  Shorter than client message, suitable for terminal display.
248
  """
249
  parts = []
250
-
251
  if self.timeout_occurred:
252
- parts.append(f"TIMEOUT: {self.total_credentials_tried} creds tried for {self.model}")
 
 
253
  else:
254
- parts.append(f"ALL CREDS EXHAUSTED: {self.total_credentials_tried} tried for {self.model}")
255
-
 
 
256
  if self.abnormal_errors:
257
  abnormal_summary = ", ".join(
258
  f"{e['credential']}={e['status_code'] or e['error_type']}"
259
  for e in self.abnormal_errors
260
  )
261
  parts.append(f"ISSUES: {abnormal_summary}")
262
-
263
  normal_summary = self.get_normal_error_summary()
264
  if normal_summary:
265
  parts.append(f"Normal: {normal_summary}")
266
-
267
  return " | ".join(parts)
268
 
269
 
@@ -296,7 +316,7 @@ def get_retry_after(error: Exception) -> Optional[int]:
296
  if isinstance(error, httpx.HTTPStatusError):
297
  headers = error.response.headers
298
  # Check standard Retry-After header (case-insensitive)
299
- retry_header = headers.get('retry-after') or headers.get('Retry-After')
300
  if retry_header:
301
  try:
302
  return int(retry_header) # Assumes seconds format
@@ -304,10 +324,13 @@ def get_retry_after(error: Exception) -> Optional[int]:
304
  pass # Might be HTTP date format, skip for now
305
 
306
  # Check X-RateLimit-Reset header (Unix timestamp)
307
- reset_header = headers.get('x-ratelimit-reset') or headers.get('X-RateLimit-Reset')
 
 
308
  if reset_header:
309
  try:
310
  import time
 
311
  reset_timestamp = int(reset_header)
312
  current_time = int(time.time())
313
  wait_seconds = reset_timestamp - current_time
@@ -357,16 +380,16 @@ def get_retry_after(error: Exception) -> Optional[int]:
357
  continue
358
 
359
  # 3. Handle duration formats like "60s", "2m", "1h"
360
- duration_match = re.search(r'(\d+)\s*([smh])', error_str)
361
  if duration_match:
362
  try:
363
  value = int(duration_match.group(1))
364
  unit = duration_match.group(2)
365
- if unit == 's':
366
  return value
367
- elif unit == 'm':
368
  return value * 60
369
- elif unit == 'h':
370
  return value * 3600
371
  except (ValueError, IndexError):
372
  pass
@@ -381,15 +404,15 @@ def get_retry_after(error: Exception) -> Optional[int]:
381
  if value.isdigit():
382
  return int(value)
383
  # Handle "60s", "2m" format in attribute
384
- duration_match = re.search(r'(\d+)\s*([smh])', value.lower())
385
  if duration_match:
386
  val = int(duration_match.group(1))
387
  unit = duration_match.group(2)
388
- if unit == 's':
389
  return val
390
- elif unit == 'm':
391
  return val * 60
392
- elif unit == 'h':
393
  return val * 3600
394
 
395
  return None
@@ -399,7 +422,7 @@ def classify_error(e: Exception) -> ClassifiedError:
399
  """
400
  Classifies an exception into a structured ClassifiedError object.
401
  Now handles both litellm and httpx exceptions.
402
-
403
  Error types and their typical handling:
404
  - rate_limit (429): Rotate key, may retry with backoff
405
  - server_error (5xx): Retry with backoff, then rotate
@@ -412,16 +435,16 @@ def classify_error(e: Exception) -> ClassifiedError:
412
  - unknown: Rotate key (safer to try another)
413
  """
414
  status_code = getattr(e, "status_code", None)
415
-
416
  if isinstance(e, httpx.HTTPStatusError): # [NEW] Handle httpx errors first
417
  status_code = e.response.status_code
418
-
419
  # Try to get error body for better classification
420
  try:
421
- error_body = e.response.text.lower() if hasattr(e.response, 'text') else ""
422
  except Exception:
423
  error_body = ""
424
-
425
  if status_code == 401:
426
  return ClassifiedError(
427
  error_type="authentication",
@@ -453,8 +476,18 @@ def classify_error(e: Exception) -> ClassifiedError:
453
  retry_after=retry_after,
454
  )
455
  if status_code == 400:
456
- # Check for context window / token limit errors
457
- if "context" in error_body or "token" in error_body or "too long" in error_body:
 
 
 
 
 
 
 
 
 
 
458
  return ClassifiedError(
459
  error_type="context_window_exceeded",
460
  original_exception=e,
@@ -465,6 +498,11 @@ def classify_error(e: Exception) -> ClassifiedError:
465
  original_exception=e,
466
  status_code=status_code,
467
  )
 
 
 
 
 
468
  if 400 <= status_code < 500:
469
  # Other 4xx errors - generally client errors
470
  return ClassifiedError(
@@ -567,7 +605,7 @@ def is_unrecoverable_error(e: Exception) -> bool:
567
  def should_rotate_on_error(classified_error: ClassifiedError) -> bool:
568
  """
569
  Determines if an error should trigger key rotation.
570
-
571
  Errors that SHOULD rotate (try another key):
572
  - rate_limit: Current key is throttled
573
  - quota_exceeded: Current key/account exhausted
@@ -576,12 +614,12 @@ def should_rotate_on_error(classified_error: ClassifiedError) -> bool:
576
  - server_error: Provider having issues (might work with different endpoint/key)
577
  - api_connection: Network issues (might be transient)
578
  - unknown: Safer to try another key
579
-
580
  Errors that should NOT rotate (fail immediately):
581
  - invalid_request: Client error in request payload (won't help to retry)
582
  - context_window_exceeded: Request too large (won't help to retry)
583
  - pre_request_callback_error: Internal proxy error
584
-
585
  Returns:
586
  True if should rotate to next key, False if should fail immediately
587
  """
@@ -596,10 +634,10 @@ def should_rotate_on_error(classified_error: ClassifiedError) -> bool:
596
  def should_retry_same_key(classified_error: ClassifiedError) -> bool:
597
  """
598
  Determines if an error should retry with the same key (with backoff).
599
-
600
  Only server errors and connection issues should retry the same key,
601
  as these are often transient.
602
-
603
  Returns:
604
  True if should retry same key, False if should rotate immediately
605
  """
 
1
  import re
2
  import json
3
+ import os
4
  from typing import Optional, Dict, Any
5
  import httpx
6
 
 
21
  def extract_retry_after_from_body(error_body: Optional[str]) -> Optional[int]:
22
  """
23
  Extract the retry-after time from an API error response body.
24
+
25
  Handles various error formats including:
26
  - Gemini CLI: "Your quota will reset after 39s."
27
  - Generic: "quota will reset after 120s", "retry after 60s"
28
+
29
  Args:
30
  error_body: The raw error response body
31
+
32
  Returns:
33
  The retry time in seconds, or None if not found
34
  """
35
  if not error_body:
36
  return None
37
+
38
  # Pattern to match various "reset after Xs" or "retry after Xs" formats
39
  patterns = [
40
  r"quota will reset after\s*(\d+)s",
 
42
  r"retry after\s*(\d+)s",
43
  r"try again in\s*(\d+)\s*seconds?",
44
  ]
45
+
46
  for pattern in patterns:
47
  match = re.search(pattern, error_body, re.IGNORECASE)
48
  if match:
 
50
  return int(match.group(1))
51
  except (ValueError, IndexError):
52
  continue
53
+
54
  return None
55
 
56
 
 
71
  # =============================================================================
72
 
73
  # Abnormal errors that require attention and should always be reported to client
74
+ ABNORMAL_ERROR_TYPES = frozenset(
75
+ {
76
+ "forbidden", # 403 - credential access issue
77
+ "authentication", # 401 - credential invalid/revoked
78
+ "pre_request_callback_error", # Internal proxy error
79
+ }
80
+ )
81
 
82
  # Normal/expected errors during operation - only report if ALL credentials fail
83
+ NORMAL_ERROR_TYPES = frozenset(
84
+ {
85
+ "rate_limit", # 429 - expected during high load
86
+ "quota_exceeded", # Expected when quota runs out
87
+ "server_error", # 5xx - transient provider issues
88
+ "api_connection", # Network issues - transient
89
+ }
90
+ )
91
 
92
 
93
  def is_abnormal_error(classified_error: "ClassifiedError") -> bool:
94
  """
95
  Check if an error is abnormal and should be reported to the client.
96
+
97
  Abnormal errors indicate credential issues that need attention:
98
  - 403 Forbidden: Credential doesn't have access
99
  - 401 Unauthorized: Credential is invalid/revoked
100
+
101
  Normal errors are expected during operation:
102
  - 429 Rate limit: Expected during high load
103
  - 5xx Server errors: Transient provider issues
 
108
  def mask_credential(credential: str) -> str:
109
  """
110
  Mask a credential for safe display in logs and error messages.
111
+
112
  - For API keys: shows last 6 characters (e.g., "...xyz123")
113
  - For OAuth file paths: shows just the filename (e.g., "antigravity_oauth_1.json")
114
  """
 
115
  if os.path.isfile(credential):
116
  return os.path.basename(credential)
117
  elif len(credential) > 6:
 
123
  class RequestErrorAccumulator:
124
  """
125
  Tracks errors encountered during a request's credential rotation cycle.
126
+
127
  Used to build informative error messages for clients when all credentials
128
  are exhausted. Distinguishes between abnormal errors (that need attention)
129
  and normal errors (expected during operation).
130
  """
131
+
132
  def __init__(self):
133
  self.abnormal_errors: list = [] # 403, 401 - always report details
134
+ self.normal_errors: list = [] # 429, 5xx - summarize only
135
+ self._tried_credentials: set = set() # Track unique credentials
136
  self.timeout_occurred: bool = False
137
  self.model: str = ""
138
  self.provider: str = ""
139
+
140
  def record_error(
141
+ self, credential: str, classified_error: "ClassifiedError", error_message: str
 
 
 
142
  ):
143
  """Record an error for a credential."""
144
+ self._tried_credentials.add(credential)
145
  masked_cred = mask_credential(credential)
146
+
147
  error_record = {
148
  "credential": masked_cred,
149
  "error_type": classified_error.error_type,
150
  "status_code": classified_error.status_code,
151
+ "message": self._truncate_message(error_message, 150),
152
  }
153
+
154
  if is_abnormal_error(classified_error):
155
  self.abnormal_errors.append(error_record)
156
  else:
157
  self.normal_errors.append(error_record)
158
+
159
+ @property
160
+ def total_credentials_tried(self) -> int:
161
+ """Return the number of unique credentials tried."""
162
+ return len(self._tried_credentials)
163
+
164
  def _truncate_message(self, message: str, max_length: int = 150) -> str:
165
  """Truncate error message for readability."""
166
  # Take first line and truncate
167
+ first_line = message.split("\n")[0]
168
  if len(first_line) > max_length:
169
  return first_line[:max_length] + "..."
170
  return first_line
171
+
172
  def has_errors(self) -> bool:
173
  """Check if any errors were recorded."""
174
  return bool(self.abnormal_errors or self.normal_errors)
175
+
176
  def has_abnormal_errors(self) -> bool:
177
  """Check if any abnormal errors were recorded."""
178
  return bool(self.abnormal_errors)
179
+
180
  def get_normal_error_summary(self) -> str:
181
  """Get a summary of normal errors (not individual details)."""
182
  if not self.normal_errors:
183
  return ""
184
+
185
  # Count by type
186
  counts = {}
187
  for err in self.normal_errors:
188
  err_type = err["error_type"]
189
  counts[err_type] = counts.get(err_type, 0) + 1
190
+
191
  # Build summary like "3 rate_limit, 1 server_error"
192
  parts = [f"{count} {err_type}" for err_type, count in counts.items()]
193
  return ", ".join(parts)
194
+
195
  def build_client_error_response(self) -> dict:
196
  """
197
  Build a structured error response for the client.
198
+
199
  Returns a dict suitable for JSON serialization in the error response.
200
  """
201
  # Determine the primary failure reason
 
205
  else:
206
  error_type = "proxy_all_credentials_exhausted"
207
  base_message = f"All {self.total_credentials_tried} credential(s) exhausted for {self.provider}"
208
+
209
  # Build human-readable message
210
  message_parts = [base_message]
211
+
212
  if self.abnormal_errors:
213
  message_parts.append("\n\nCredential issues (require attention):")
214
  for err in self.abnormal_errors:
215
+ status = (
216
+ f"HTTP {err['status_code']}"
217
+ if err["status_code"] is not None
218
+ else err["error_type"]
219
+ )
220
+ message_parts.append(
221
+ f"\n • {err['credential']}: {status} - {err['message']}"
222
+ )
223
+
224
  normal_summary = self.get_normal_error_summary()
225
  if normal_summary:
226
  if self.abnormal_errors:
227
+ message_parts.append(
228
+ f"\n\nAdditionally: {normal_summary} (expected during normal operation)"
229
+ )
230
  else:
231
  message_parts.append(f"\n\nAll failures were: {normal_summary}")
232
+ message_parts.append(
233
+ "\nThis is normal during high load - retry later or add more credentials."
234
+ )
235
+
236
  response = {
237
  "error": {
238
  "message": "".join(message_parts),
 
242
  "provider": self.provider,
243
  "credentials_tried": self.total_credentials_tried,
244
  "timeout": self.timeout_occurred,
245
+ },
246
  }
247
  }
248
+
249
  # Only include abnormal errors in details (they need attention)
250
  if self.abnormal_errors:
251
  response["error"]["details"]["abnormal_errors"] = self.abnormal_errors
252
+
253
  # Include summary of normal errors
254
  if normal_summary:
255
  response["error"]["details"]["normal_error_summary"] = normal_summary
256
+
257
  return response
258
+
259
  def build_log_message(self) -> str:
260
  """
261
  Build a concise log message for server-side logging.
262
+
263
  Shorter than client message, suitable for terminal display.
264
  """
265
  parts = []
266
+
267
  if self.timeout_occurred:
268
+ parts.append(
269
+ f"TIMEOUT: {self.total_credentials_tried} creds tried for {self.model}"
270
+ )
271
  else:
272
+ parts.append(
273
+ f"ALL CREDS EXHAUSTED: {self.total_credentials_tried} tried for {self.model}"
274
+ )
275
+
276
  if self.abnormal_errors:
277
  abnormal_summary = ", ".join(
278
  f"{e['credential']}={e['status_code'] or e['error_type']}"
279
  for e in self.abnormal_errors
280
  )
281
  parts.append(f"ISSUES: {abnormal_summary}")
282
+
283
  normal_summary = self.get_normal_error_summary()
284
  if normal_summary:
285
  parts.append(f"Normal: {normal_summary}")
286
+
287
  return " | ".join(parts)
288
 
289
 
 
316
  if isinstance(error, httpx.HTTPStatusError):
317
  headers = error.response.headers
318
  # Check standard Retry-After header (case-insensitive)
319
+ retry_header = headers.get("retry-after") or headers.get("Retry-After")
320
  if retry_header:
321
  try:
322
  return int(retry_header) # Assumes seconds format
 
324
  pass # Might be HTTP date format, skip for now
325
 
326
  # Check X-RateLimit-Reset header (Unix timestamp)
327
+ reset_header = headers.get("x-ratelimit-reset") or headers.get(
328
+ "X-RateLimit-Reset"
329
+ )
330
  if reset_header:
331
  try:
332
  import time
333
+
334
  reset_timestamp = int(reset_header)
335
  current_time = int(time.time())
336
  wait_seconds = reset_timestamp - current_time
 
380
  continue
381
 
382
  # 3. Handle duration formats like "60s", "2m", "1h"
383
+ duration_match = re.search(r"(\d+)\s*([smh])", error_str)
384
  if duration_match:
385
  try:
386
  value = int(duration_match.group(1))
387
  unit = duration_match.group(2)
388
+ if unit == "s":
389
  return value
390
+ elif unit == "m":
391
  return value * 60
392
+ elif unit == "h":
393
  return value * 3600
394
  except (ValueError, IndexError):
395
  pass
 
404
  if value.isdigit():
405
  return int(value)
406
  # Handle "60s", "2m" format in attribute
407
+ duration_match = re.search(r"(\d+)\s*([smh])", value.lower())
408
  if duration_match:
409
  val = int(duration_match.group(1))
410
  unit = duration_match.group(2)
411
+ if unit == "s":
412
  return val
413
+ elif unit == "m":
414
  return val * 60
415
+ elif unit == "h":
416
  return val * 3600
417
 
418
  return None
 
422
  """
423
  Classifies an exception into a structured ClassifiedError object.
424
  Now handles both litellm and httpx exceptions.
425
+
426
  Error types and their typical handling:
427
  - rate_limit (429): Rotate key, may retry with backoff
428
  - server_error (5xx): Retry with backoff, then rotate
 
435
  - unknown: Rotate key (safer to try another)
436
  """
437
  status_code = getattr(e, "status_code", None)
438
+
439
  if isinstance(e, httpx.HTTPStatusError): # [NEW] Handle httpx errors first
440
  status_code = e.response.status_code
441
+
442
  # Try to get error body for better classification
443
  try:
444
+ error_body = e.response.text.lower() if hasattr(e.response, "text") else ""
445
  except Exception:
446
  error_body = ""
447
+
448
  if status_code == 401:
449
  return ClassifiedError(
450
  error_type="authentication",
 
476
  retry_after=retry_after,
477
  )
478
  if status_code == 400:
479
+ # Check for context window / token limit errors with more specific patterns
480
+ if any(
481
+ pattern in error_body
482
+ for pattern in [
483
+ "context_length",
484
+ "max_tokens",
485
+ "token limit",
486
+ "context window",
487
+ "too many tokens",
488
+ "too long",
489
+ ]
490
+ ):
491
  return ClassifiedError(
492
  error_type="context_window_exceeded",
493
  original_exception=e,
 
498
  original_exception=e,
499
  status_code=status_code,
500
  )
501
+ return ClassifiedError(
502
+ error_type="invalid_request",
503
+ original_exception=e,
504
+ status_code=status_code,
505
+ )
506
  if 400 <= status_code < 500:
507
  # Other 4xx errors - generally client errors
508
  return ClassifiedError(
 
605
  def should_rotate_on_error(classified_error: ClassifiedError) -> bool:
606
  """
607
  Determines if an error should trigger key rotation.
608
+
609
  Errors that SHOULD rotate (try another key):
610
  - rate_limit: Current key is throttled
611
  - quota_exceeded: Current key/account exhausted
 
614
  - server_error: Provider having issues (might work with different endpoint/key)
615
  - api_connection: Network issues (might be transient)
616
  - unknown: Safer to try another key
617
+
618
  Errors that should NOT rotate (fail immediately):
619
  - invalid_request: Client error in request payload (won't help to retry)
620
  - context_window_exceeded: Request too large (won't help to retry)
621
  - pre_request_callback_error: Internal proxy error
622
+
623
  Returns:
624
  True if should rotate to next key, False if should fail immediately
625
  """
 
634
  def should_retry_same_key(classified_error: ClassifiedError) -> bool:
635
  """
636
  Determines if an error should retry with the same key (with backoff).
637
+
638
  Only server errors and connection issues should retry the same key,
639
  as these are often transient.
640
+
641
  Returns:
642
  True if should retry same key, False if should rotate immediately
643
  """
src/rotator_library/failure_logger.py CHANGED
@@ -4,6 +4,7 @@ from logging.handlers import RotatingFileHandler
4
  import os
5
  from datetime import datetime
6
 
 
7
  def setup_failure_logger():
8
  """Sets up a dedicated JSON logger for writing detailed failure logs to a file."""
9
  log_dir = "logs"
@@ -12,15 +13,15 @@ def setup_failure_logger():
12
 
13
  # Create a logger specifically for failures.
14
  # This logger will NOT propagate to the root logger.
15
- logger = logging.getLogger('failure_logger')
16
  logger.setLevel(logging.INFO)
17
  logger.propagate = False
18
 
19
  # Use a rotating file handler
20
  handler = RotatingFileHandler(
21
- os.path.join(log_dir, 'failures.log'),
22
- maxBytes=5*1024*1024, # 5 MB
23
- backupCount=2
24
  )
25
 
26
  # Custom JSON formatter for structured logs
@@ -30,62 +31,65 @@ def setup_failure_logger():
30
  return json.dumps(record.msg)
31
 
32
  handler.setFormatter(JsonFormatter())
33
-
34
  # Add handler only if it hasn't been added before
35
  if not logger.handlers:
36
  logger.addHandler(handler)
37
 
38
  return logger
39
 
 
40
  # Initialize the dedicated logger for detailed failure logs
41
  failure_logger = setup_failure_logger()
42
 
43
  # Get the main library logger for concise, propagated messages
44
- main_lib_logger = logging.getLogger('rotator_library')
 
45
 
46
  def _extract_response_body(error: Exception) -> str:
47
  """
48
  Extract the full response body from various error types.
49
-
50
  Handles:
51
  - httpx.HTTPStatusError: response.text or response.content
52
  - litellm exceptions: various response attributes
53
  - Other exceptions: str(error)
54
  """
55
  # Try to get response body from httpx errors
56
- if hasattr(error, 'response') and error.response is not None:
57
  response = error.response
58
  # Try .text first (decoded)
59
- if hasattr(response, 'text') and response.text:
60
  return response.text
61
  # Try .content (bytes)
62
- if hasattr(response, 'content') and response.content:
63
  try:
64
- return response.content.decode('utf-8', errors='replace')
65
  except Exception:
66
  return str(response.content)
67
- # Try reading response if it's a streaming response that was read
68
- if hasattr(response, '_content') and response._content:
69
- try:
70
- return response._content.decode('utf-8', errors='replace')
71
- except Exception:
72
- return str(response._content)
73
-
74
  # Check for litellm's body attribute
75
- if hasattr(error, 'body') and error.body:
76
  return str(error.body)
77
-
78
  # Check for message attribute that might contain response
79
- if hasattr(error, 'message') and error.message:
80
  return str(error.message)
81
-
82
  return None
83
 
84
 
85
- def log_failure(api_key: str, model: str, attempt: int, error: Exception, request_headers: dict, raw_response_text: str = None):
 
 
 
 
 
 
 
86
  """
87
  Logs a detailed failure message to a file and a concise summary to the main logger.
88
-
89
  Args:
90
  api_key: The API key or credential path that was used
91
  model: The model that was requested
@@ -103,19 +107,30 @@ def log_failure(api_key: str, model: str, attempt: int, error: Exception, reques
103
 
104
  # Get full error message (not truncated)
105
  full_error_message = str(error)
106
-
107
  # Also capture any nested/wrapped exception info
108
  error_chain = []
 
109
  current_error = error
110
  while current_error:
111
- error_chain.append({
112
- "type": type(current_error).__name__,
113
- "message": str(current_error)[:2000] # Limit per-error message size
114
- })
115
- current_error = getattr(current_error, '__cause__', None) or getattr(current_error, '__context__', None)
116
- if len(error_chain) > 5: # Prevent infinite loops
117
  break
118
-
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  detailed_log_data = {
120
  "timestamp": datetime.utcnow().isoformat(),
121
  "api_key_ending": api_key[-4:] if len(api_key) >= 4 else "****",
@@ -123,7 +138,9 @@ def log_failure(api_key: str, model: str, attempt: int, error: Exception, reques
123
  "attempt_number": attempt,
124
  "error_type": type(error).__name__,
125
  "error_message": full_error_message[:5000], # Limit total size
126
- "raw_response": raw_response[:10000] if raw_response else None, # Limit response size
 
 
127
  "request_headers": request_headers,
128
  "error_chain": error_chain if len(error_chain) > 1 else None,
129
  }
 
4
  import os
5
  from datetime import datetime
6
 
7
+
8
  def setup_failure_logger():
9
  """Sets up a dedicated JSON logger for writing detailed failure logs to a file."""
10
  log_dir = "logs"
 
13
 
14
  # Create a logger specifically for failures.
15
  # This logger will NOT propagate to the root logger.
16
+ logger = logging.getLogger("failure_logger")
17
  logger.setLevel(logging.INFO)
18
  logger.propagate = False
19
 
20
  # Use a rotating file handler
21
  handler = RotatingFileHandler(
22
+ os.path.join(log_dir, "failures.log"),
23
+ maxBytes=5 * 1024 * 1024, # 5 MB
24
+ backupCount=2,
25
  )
26
 
27
  # Custom JSON formatter for structured logs
 
31
  return json.dumps(record.msg)
32
 
33
  handler.setFormatter(JsonFormatter())
34
+
35
  # Add handler only if it hasn't been added before
36
  if not logger.handlers:
37
  logger.addHandler(handler)
38
 
39
  return logger
40
 
41
+
42
  # Initialize the dedicated logger for detailed failure logs
43
  failure_logger = setup_failure_logger()
44
 
45
  # Get the main library logger for concise, propagated messages
46
+ main_lib_logger = logging.getLogger("rotator_library")
47
+
48
 
49
  def _extract_response_body(error: Exception) -> str:
50
  """
51
  Extract the full response body from various error types.
52
+
53
  Handles:
54
  - httpx.HTTPStatusError: response.text or response.content
55
  - litellm exceptions: various response attributes
56
  - Other exceptions: str(error)
57
  """
58
  # Try to get response body from httpx errors
59
+ if hasattr(error, "response") and error.response is not None:
60
  response = error.response
61
  # Try .text first (decoded)
62
+ if hasattr(response, "text") and response.text:
63
  return response.text
64
  # Try .content (bytes)
65
+ if hasattr(response, "content") and response.content:
66
  try:
67
+ return response.content.decode("utf-8", errors="replace")
68
  except Exception:
69
  return str(response.content)
70
+
 
 
 
 
 
 
71
  # Check for litellm's body attribute
72
+ if hasattr(error, "body") and error.body:
73
  return str(error.body)
74
+
75
  # Check for message attribute that might contain response
76
+ if hasattr(error, "message") and error.message:
77
  return str(error.message)
78
+
79
  return None
80
 
81
 
82
+ def log_failure(
83
+ api_key: str,
84
+ model: str,
85
+ attempt: int,
86
+ error: Exception,
87
+ request_headers: dict,
88
+ raw_response_text: str = None,
89
+ ):
90
  """
91
  Logs a detailed failure message to a file and a concise summary to the main logger.
92
+
93
  Args:
94
  api_key: The API key or credential path that was used
95
  model: The model that was requested
 
107
 
108
  # Get full error message (not truncated)
109
  full_error_message = str(error)
110
+
111
  # Also capture any nested/wrapped exception info
112
  error_chain = []
113
+ visited = set() # Track visited exceptions to detect circular references
114
  current_error = error
115
  while current_error:
116
+ # Check for circular references
117
+ error_id = id(current_error)
118
+ if error_id in visited:
 
 
 
119
  break
120
+ visited.add(error_id)
121
+
122
+ error_chain.append(
123
+ {
124
+ "type": type(current_error).__name__,
125
+ "message": str(current_error)[:2000], # Limit per-error message size
126
+ }
127
+ )
128
+ current_error = getattr(current_error, "__cause__", None) or getattr(
129
+ current_error, "__context__", None
130
+ )
131
+ if len(error_chain) > 5: # Prevent excessive chain length
132
+ break
133
+
134
  detailed_log_data = {
135
  "timestamp": datetime.utcnow().isoformat(),
136
  "api_key_ending": api_key[-4:] if len(api_key) >= 4 else "****",
 
138
  "attempt_number": attempt,
139
  "error_type": type(error).__name__,
140
  "error_message": full_error_message[:5000], # Limit total size
141
+ "raw_response": raw_response[:10000]
142
+ if raw_response
143
+ else None, # Limit response size
144
  "request_headers": request_headers,
145
  "error_chain": error_chain if len(error_chain) > 1 else None,
146
  }