Mirrowel committed on
Commit
39e01ca
·
1 Parent(s): 225c46e

feat(core): enable static model configuration and LiteLLM compatibility for custom providers

Browse files

Introduce static model definitions and runtime parameter conversion for custom endpoints.

This change significantly improves compatibility with self-hosted or dynamically configured OpenAI-compatible APIs:

- **Model Definitions:** Adds a new `ModelDefinitions` utility to load static model configurations (IDs, options like `reasoning_effort`) from environment variables (e.g., `PROVIDER_MODELS`).
- **Dynamic Providers:** Extends provider discovery to dynamically register new providers whenever an `_API_BASE` environment variable is detected.
- **Client Conversion:** Implements `_convert_model_params_for_litellm` in the RotatingClient to rewrite the model argument (to `openai/{model_id}`) and inject the necessary `api_base` and `custom_llm_provider` kwargs right before calling LiteLLM.
- **Option Application:** Ensures that model options loaded from the static definitions are merged into the LiteLLM request arguments.
- **Cost Management:** Configures custom providers to skip cost calculation in the `UsageManager`.

src/rotator_library/client.py CHANGED
@@ -297,6 +297,32 @@ class RotatingClient:
297
 
298
  return kwargs
299
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
  def get_oauth_credentials(self) -> Dict[str, List[str]]:
301
  return self.oauth_credentials
302
 
@@ -566,6 +592,18 @@ class RotatingClient:
566
  }
567
 
568
  provider_plugin = self._get_provider_instance(provider)
 
 
 
 
 
 
 
 
 
 
 
 
569
  if provider_plugin and provider_plugin.has_custom_logic():
570
  lib_logger.debug(
571
  f"Provider '{provider}' has custom logic. Delegating call."
@@ -666,8 +704,13 @@ class RotatingClient:
666
  f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}"
667
  )
668
 
 
 
 
 
 
669
  response = await api_call(
670
- **litellm_kwargs,
671
  logger_fn=self._litellm_logger_callback,
672
  )
673
 
@@ -912,6 +955,19 @@ class RotatingClient:
912
  }
913
 
914
  provider_plugin = self._get_provider_instance(provider)
 
 
 
 
 
 
 
 
 
 
 
 
 
915
  if provider_plugin and provider_plugin.has_custom_logic():
916
  lib_logger.debug(
917
  f"Provider '{provider}' has custom logic. Delegating call."
@@ -1121,8 +1177,13 @@ class RotatingClient:
1121
  )
1122
 
1123
  # lib_logger.info(f"DEBUG: litellm.acompletion kwargs: {litellm_kwargs}")
 
 
 
 
 
1124
  response = await litellm.acompletion(
1125
- **litellm_kwargs,
1126
  logger_fn=self._litellm_logger_callback,
1127
  )
1128
 
 
297
 
298
  return kwargs
299
 
300
+ def _convert_model_params_for_litellm(self, **kwargs) -> Dict[str, Any]:
301
+ """
302
+ Converts model parameters specifically for LiteLLM calls.
303
+ This is called right before calling LiteLLM to handle custom providers.
304
+ """
305
+ model = kwargs.get("model")
306
+ if not model:
307
+ return kwargs
308
+
309
+ provider = model.split("/")[0]
310
+
311
+ # Handle custom OpenAI-compatible providers
312
+ # Check if this is a custom provider by looking for API_BASE environment variable
313
+ import os
314
+
315
+ api_base_env = f"{provider.upper()}_API_BASE"
316
+ if os.getenv(api_base_env):
317
+ # For custom providers, tell LiteLLM to use openai provider with custom model name
318
+ # This preserves original model name in logs but converts for LiteLLM
319
+ kwargs = kwargs.copy() # Don't modify original
320
+ kwargs["model"] = f"openai/{model.split('/', 1)[1]}"
321
+ kwargs["api_base"] = os.getenv(api_base_env).rstrip("/")
322
+ kwargs["custom_llm_provider"] = "openai"
323
+
324
+ return kwargs
325
+
326
  def get_oauth_credentials(self) -> Dict[str, List[str]]:
327
  return self.oauth_credentials
328
 
 
592
  }
593
 
594
  provider_plugin = self._get_provider_instance(provider)
595
+
596
+ # Apply model-specific options for custom providers
597
+ if provider_plugin and hasattr(provider_plugin, "get_model_options"):
598
+ model_options = provider_plugin.get_model_options(model)
599
+ if model_options:
600
+ # Merge model options into litellm_kwargs
601
+ for key, value in model_options.items():
602
+ if key == "reasoning_effort":
603
+ litellm_kwargs["reasoning_effort"] = value
604
+ elif key not in litellm_kwargs:
605
+ litellm_kwargs[key] = value
606
+
607
  if provider_plugin and provider_plugin.has_custom_logic():
608
  lib_logger.debug(
609
  f"Provider '{provider}' has custom logic. Delegating call."
 
704
  f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}"
705
  )
706
 
707
+ # Convert model parameters for custom providers right before LiteLLM call
708
+ final_kwargs = self._convert_model_params_for_litellm(
709
+ **litellm_kwargs
710
+ )
711
+
712
  response = await api_call(
713
+ **final_kwargs,
714
  logger_fn=self._litellm_logger_callback,
715
  )
716
 
 
955
  }
956
 
957
  provider_plugin = self._get_provider_instance(provider)
958
+
959
+ # Apply model-specific options for custom providers
960
+ if provider_plugin and hasattr(
961
+ provider_plugin, "get_model_options"
962
+ ):
963
+ model_options = provider_plugin.get_model_options(model)
964
+ if model_options:
965
+ # Merge model options into litellm_kwargs
966
+ for key, value in model_options.items():
967
+ if key == "reasoning_effort":
968
+ litellm_kwargs["reasoning_effort"] = value
969
+ elif key not in litellm_kwargs:
970
+ litellm_kwargs[key] = value
971
  if provider_plugin and provider_plugin.has_custom_logic():
972
  lib_logger.debug(
973
  f"Provider '{provider}' has custom logic. Delegating call."
 
1177
  )
1178
 
1179
  # lib_logger.info(f"DEBUG: litellm.acompletion kwargs: {litellm_kwargs}")
1180
+ # Convert model parameters for custom providers right before LiteLLM call
1181
+ final_kwargs = self._convert_model_params_for_litellm(
1182
+ **litellm_kwargs
1183
+ )
1184
+
1185
  response = await litellm.acompletion(
1186
+ **final_kwargs,
1187
  logger_fn=self._litellm_logger_callback,
1188
  )
1189
 
src/rotator_library/error_handler.py CHANGED
@@ -269,7 +269,7 @@ class AllProviders:
269
  api_base = os.getenv(env_var)
270
  if api_base:
271
  self.providers[provider_name] = {
272
- "api_base": api_base.rstrip("/") if api_base else None,
273
  "model_prefix": None, # No prefix for custom providers
274
  }
275
 
 
269
  api_base = os.getenv(env_var)
270
  if api_base:
271
  self.providers[provider_name] = {
272
+ "api_base": api_base.rstrip("/") if api_base else "",
273
  "model_prefix": None, # No prefix for custom providers
274
  }
275
 
src/rotator_library/model_definitions.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import logging
4
+ from typing import Dict, Any, Optional
5
+
6
+ lib_logger = logging.getLogger("rotator_library")
7
+ lib_logger.propagate = False
8
+ if not lib_logger.handlers:
9
+ lib_logger.addHandler(logging.NullHandler())
10
+
11
+
12
+ class ModelDefinitions:
13
+ """
14
+ Simple model definitions loader from environment variables.
15
+ Format: PROVIDER_MODELS={"model1": {"id": "id1"}, "model2": {"id": "id2", "options": {"reasoning_effort": "high"}}}
16
+ """
17
+
18
+ def __init__(self, config_path: Optional[str] = None):
19
+ """Initialize model definitions loader."""
20
+ self.config_path = config_path
21
+ self.definitions = {}
22
+ self._load_definitions()
23
+
24
+ def _load_definitions(self):
25
+ """Load model definitions from environment variables."""
26
+ for env_var, env_value in os.environ.items():
27
+ if env_var.endswith("_MODELS"):
28
+ provider_name = env_var[:-7].lower() # Remove "_MODELS" (7 characters)
29
+ try:
30
+ models_json = json.loads(env_value)
31
+ if isinstance(models_json, dict):
32
+ self.definitions[provider_name] = models_json
33
+ lib_logger.info(
34
+ f"Loaded {len(models_json)} models for provider: {provider_name}"
35
+ )
36
+ except (json.JSONDecodeError, TypeError) as e:
37
+ lib_logger.warning(f"Invalid JSON in {env_var}: {e}")
38
+
39
+ def get_provider_models(self, provider_name: str) -> Dict[str, Any]:
40
+ """Get all models for a provider."""
41
+ return self.definitions.get(provider_name, {})
42
+
43
+ def get_model_definition(
44
+ self, provider_name: str, model_name: str
45
+ ) -> Optional[Dict[str, Any]]:
46
+ """Get a specific model definition."""
47
+ provider_models = self.get_provider_models(provider_name)
48
+ return provider_models.get(model_name)
49
+
50
+ def get_model_options(self, provider_name: str, model_name: str) -> Dict[str, Any]:
51
+ """Get options for a specific model."""
52
+ model_def = self.get_model_definition(provider_name, model_name)
53
+ return model_def.get("options", {}) if model_def else {}
54
+
55
+ def get_model_id(self, provider_name: str, model_name: str) -> Optional[str]:
56
+ """Get model ID for a specific model."""
57
+ model_def = self.get_model_definition(provider_name, model_name)
58
+ return model_def.get("id") if model_def else None
59
+
60
+ def get_all_provider_models(self, provider_name: str) -> list:
61
+ """Get all model names with provider prefix."""
62
+ provider_models = self.get_provider_models(provider_name)
63
+ return [f"{provider_name}/{model}" for model in provider_models.keys()]
64
+
65
+ def reload_definitions(self):
66
+ """Reload model definitions from environment variables."""
67
+ self.definitions.clear()
68
+ self._load_definitions()
src/rotator_library/providers/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
  import importlib
2
  import pkgutil
 
3
  from typing import Dict, Type
4
  from .provider_interface import ProviderInterface
5
 
@@ -8,31 +9,126 @@ from .provider_interface import ProviderInterface
8
  # Dictionary to hold discovered provider classes, mapping provider name to class
9
  PROVIDER_PLUGINS: Dict[str, Type[ProviderInterface]] = {}
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def _register_providers():
12
  """
13
  Dynamically discovers and imports provider plugins from this directory.
 
14
  """
15
  package_path = __path__
16
  package_name = __name__
17
 
 
18
  for _, module_name, _ in pkgutil.iter_modules(package_path):
19
  # Construct the full module path
20
  full_module_path = f"{package_name}.{module_name}"
21
-
22
  # Import the module
23
  module = importlib.import_module(full_module_path)
24
 
25
  # Look for a class that inherits from ProviderInterface
26
  for attribute_name in dir(module):
27
  attribute = getattr(module, attribute_name)
28
- if isinstance(attribute, type) and issubclass(attribute, ProviderInterface) and attribute is not ProviderInterface:
 
 
 
 
29
  # Derives 'gemini_cli' from 'gemini_cli_provider.py'
30
  # Remap 'nvidia' to 'nvidia_nim' to align with litellm's provider name
31
  provider_name = module_name.replace("_provider", "")
32
  if provider_name == "nvidia":
33
  provider_name = "nvidia_nim"
34
  PROVIDER_PLUGINS[provider_name] = attribute
35
- #print(f"Registered provider: {provider_name}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  # Discover and register providers when the package is imported
38
  _register_providers()
 
1
  import importlib
2
  import pkgutil
3
+ import os
4
  from typing import Dict, Type
5
  from .provider_interface import ProviderInterface
6
 
 
9
  # Dictionary to hold discovered provider classes, mapping provider name to class
10
  PROVIDER_PLUGINS: Dict[str, Type[ProviderInterface]] = {}
11
 
12
+
13
+ class DynamicOpenAICompatibleProvider:
14
+ """
15
+ Dynamic provider class for custom OpenAI-compatible providers.
16
+ Created at runtime for providers with API_BASE environment variables.
17
+ """
18
+
19
+ def __init__(self, provider_name: str):
20
+ self.provider_name = provider_name
21
+ # Get API base URL from environment
22
+ self.api_base = os.getenv(f"{provider_name.upper()}_API_BASE")
23
+ if not self.api_base:
24
+ raise ValueError(
25
+ f"Environment variable {provider_name.upper()}_API_BASE is required for OpenAI-compatible provider"
26
+ )
27
+
28
+ # Import model definitions
29
+ from ..model_definitions import ModelDefinitions
30
+
31
+ self.model_definitions = ModelDefinitions()
32
+
33
+ def skip_cost_calculation(self) -> bool:
34
+ """Custom providers should skip cost calculation."""
35
+ return True
36
+
37
+ def get_models(self, api_key: str, client):
38
+ """Delegate to OpenAI-compatible provider implementation."""
39
+ from .openai_compatible_provider import OpenAICompatibleProvider
40
+
41
+ # Create temporary instance to reuse logic
42
+ temp_provider = OpenAICompatibleProvider(self.provider_name)
43
+ return temp_provider.get_models(api_key, client)
44
+
45
+ def get_model_options(self, model_name: str) -> Dict[str, any]:
46
+ """Get model options from static definitions."""
47
+ # Extract model name without provider prefix if present
48
+ if "/" in model_name:
49
+ model_name = model_name.split("/")[-1]
50
+
51
+ return self.model_definitions.get_model_options(self.provider_name, model_name)
52
+
53
+ def has_custom_logic(self) -> bool:
54
+ """Returns False since we want to use the standard litellm flow."""
55
+ return False
56
+
57
+ def get_auth_header(self, credential_identifier: str) -> Dict[str, str]:
58
+ """Returns the standard Bearer token header."""
59
+ return {"Authorization": f"Bearer {credential_identifier}"}
60
+
61
+
62
  def _register_providers():
63
  """
64
  Dynamically discovers and imports provider plugins from this directory.
65
+ Also creates dynamic plugins for custom OpenAI-compatible providers.
66
  """
67
  package_path = __path__
68
  package_name = __name__
69
 
70
+ # First, register file-based providers
71
  for _, module_name, _ in pkgutil.iter_modules(package_path):
72
  # Construct the full module path
73
  full_module_path = f"{package_name}.{module_name}"
74
+
75
  # Import the module
76
  module = importlib.import_module(full_module_path)
77
 
78
  # Look for a class that inherits from ProviderInterface
79
  for attribute_name in dir(module):
80
  attribute = getattr(module, attribute_name)
81
+ if (
82
+ isinstance(attribute, type)
83
+ and issubclass(attribute, ProviderInterface)
84
+ and attribute is not ProviderInterface
85
+ ):
86
  # Derives 'gemini_cli' from 'gemini_cli_provider.py'
87
  # Remap 'nvidia' to 'nvidia_nim' to align with litellm's provider name
88
  provider_name = module_name.replace("_provider", "")
89
  if provider_name == "nvidia":
90
  provider_name = "nvidia_nim"
91
  PROVIDER_PLUGINS[provider_name] = attribute
92
+ # print(f"Registered provider: {provider_name}")
93
+
94
+ # Then, create dynamic plugins for custom OpenAI-compatible providers
95
+ # Load environment variables to find custom providers
96
+ from dotenv import load_dotenv
97
+
98
+ load_dotenv()
99
+
100
+ for env_var in os.environ:
101
+ if env_var.endswith("_API_BASE"):
102
+ provider_name = env_var[:-9].lower() # Remove '_API_BASE' suffix
103
+
104
+ # Skip known providers that already have file-based plugins
105
+ if provider_name in [
106
+ "openai",
107
+ "anthropic",
108
+ "google",
109
+ "gemini",
110
+ "nvidia",
111
+ "mistral",
112
+ "cohere",
113
+ "groq",
114
+ "openrouter",
115
+ "chutes",
116
+ ]:
117
+ continue
118
+
119
+ # Create a dynamic plugin class
120
+ def create_plugin_class(name):
121
+ class DynamicPlugin(DynamicOpenAICompatibleProvider):
122
+ def __init__(self):
123
+ super().__init__(name)
124
+
125
+ return DynamicPlugin
126
+
127
+ # Create and register the plugin class
128
+ plugin_class = create_plugin_class(provider_name)
129
+ PROVIDER_PLUGINS[provider_name] = plugin_class
130
+ # print(f"Registered dynamic provider: {provider_name}")
131
+
132
 
133
  # Discover and register providers when the package is imported
134
  _register_providers()
src/rotator_library/providers/openai_compatible_provider.py CHANGED
@@ -3,6 +3,7 @@ import httpx
3
  import logging
4
  from typing import List, Dict, Any, Optional
5
  from .provider_interface import ProviderInterface
 
6
 
7
  lib_logger = logging.getLogger("rotator_library")
8
  lib_logger.propagate = False
@@ -15,7 +16,11 @@ class OpenAICompatibleProvider(ProviderInterface):
15
  Generic provider implementation for any OpenAI-compatible API.
16
  This provider can be configured via environment variables to support
17
  custom OpenAI-compatible endpoints without requiring code changes.
 
18
  """
 
 
 
19
 
20
  def __init__(self, provider_name: str):
21
  self.provider_name = provider_name
@@ -26,28 +31,70 @@ class OpenAICompatibleProvider(ProviderInterface):
26
  f"Environment variable {provider_name.upper()}_API_BASE is required for OpenAI-compatible provider"
27
  )
28
 
 
 
 
29
  async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]:
30
  """
31
  Fetches the list of available models from the OpenAI-compatible API.
 
32
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  try:
34
  models_url = f"{self.api_base.rstrip('/')}/models"
35
  response = await client.get(
36
  models_url, headers={"Authorization": f"Bearer {api_key}"}
37
  )
38
  response.raise_for_status()
39
- return [
 
40
  f"{self.provider_name}/{model['id']}"
41
  for model in response.json().get("data", [])
 
42
  ]
43
- except httpx.RequestError as e:
44
- lib_logger.error(f"Failed to fetch models for {self.provider_name}: {e}")
45
- return []
46
- except Exception as e:
47
- lib_logger.error(
48
- f"Unexpected error fetching models for {self.provider_name}: {e}"
49
- )
50
- return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  def has_custom_logic(self) -> bool:
53
  """
 
3
  import logging
4
  from typing import List, Dict, Any, Optional
5
  from .provider_interface import ProviderInterface
6
+ from ..model_definitions import ModelDefinitions
7
 
8
  lib_logger = logging.getLogger("rotator_library")
9
  lib_logger.propagate = False
 
16
  Generic provider implementation for any OpenAI-compatible API.
17
  This provider can be configured via environment variables to support
18
  custom OpenAI-compatible endpoints without requiring code changes.
19
+ Supports both dynamic model discovery and static model definitions.
20
  """
21
+
22
+ skip_cost_calculation: bool = True # Skip cost calculation for custom providers
23
+
24
 
25
  def __init__(self, provider_name: str):
26
  self.provider_name = provider_name
 
31
  f"Environment variable {provider_name.upper()}_API_BASE is required for OpenAI-compatible provider"
32
  )
33
 
34
+ # Initialize model definitions loader
35
+ self.model_definitions = ModelDefinitions()
36
+
37
  async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]:
38
  """
39
  Fetches the list of available models from the OpenAI-compatible API.
40
+ Combines dynamic discovery with static model definitions.
41
  """
42
+ models = []
43
+
44
+ # First, try to get static model definitions
45
+ static_models = self.model_definitions.get_all_provider_models(
46
+ self.provider_name
47
+ )
48
+ if static_models:
49
+ models.extend(static_models)
50
+ lib_logger.info(
51
+ f"Loaded {len(static_models)} static models for {self.provider_name}"
52
+ )
53
+
54
+ # Then, try dynamic discovery to get additional models
55
  try:
56
  models_url = f"{self.api_base.rstrip('/')}/models"
57
  response = await client.get(
58
  models_url, headers={"Authorization": f"Bearer {api_key}"}
59
  )
60
  response.raise_for_status()
61
+
62
+ dynamic_models = [
63
  f"{self.provider_name}/{model['id']}"
64
  for model in response.json().get("data", [])
65
+ if model["id"] not in [m.split("/")[-1] for m in static_models]
66
  ]
67
+
68
+ if dynamic_models:
69
+ models.extend(dynamic_models)
70
+ lib_logger.debug(
71
+ f"Discovered {len(dynamic_models)} additional models for {self.provider_name}"
72
+ )
73
+
74
+ except httpx.RequestError:
75
+ # Silently ignore dynamic discovery errors
76
+ pass
77
+ except Exception:
78
+ # Silently ignore dynamic discovery errors
79
+ pass
80
+
81
+ return models
82
+
83
+ def get_model_options(self, model_name: str) -> Dict[str, Any]:
84
+ """
85
+ Get options for a specific model from static definitions or environment variables.
86
+
87
+ Args:
88
+ model_name: Model name (without provider prefix)
89
+
90
+ Returns:
91
+ Dictionary of model options
92
+ """
93
+ # Extract model name without provider prefix if present
94
+ if "/" in model_name:
95
+ model_name = model_name.split("/")[-1]
96
+
97
+ return self.model_definitions.get_model_options(self.provider_name, model_name)
98
 
99
  def has_custom_logic(self) -> bool:
100
  """
src/rotator_library/usage_manager.py CHANGED
@@ -11,20 +11,26 @@ import litellm
11
  from .error_handler import ClassifiedError, NoAvailableKeysError
12
  from .providers import PROVIDER_PLUGINS
13
 
14
- lib_logger = logging.getLogger('rotator_library')
15
  lib_logger.propagate = False
16
  if not lib_logger.handlers:
17
  lib_logger.addHandler(logging.NullHandler())
18
 
 
19
  class UsageManager:
20
  """
21
  Manages usage statistics and cooldowns for API keys with asyncio-safe locking,
22
  asynchronous file I/O, and a lazy-loading mechanism for usage data.
23
  """
24
- def __init__(self, file_path: str = "key_usage.json", daily_reset_time_utc: Optional[str] = "03:00"):
 
 
 
 
 
25
  self.file_path = file_path
26
  self.key_states: Dict[str, Dict[str, Any]] = {}
27
-
28
  self._data_lock = asyncio.Lock()
29
  self._usage_data: Optional[Dict] = None
30
  self._initialized = asyncio.Event()
@@ -34,8 +40,10 @@ class UsageManager:
34
  self._claimed_on_timeout: Set[str] = set()
35
 
36
  if daily_reset_time_utc:
37
- hour, minute = map(int, daily_reset_time_utc.split(':'))
38
- self.daily_reset_time_utc = dt_time(hour=hour, minute=minute, tzinfo=timezone.utc)
 
 
39
  else:
40
  self.daily_reset_time_utc = None
41
 
@@ -54,7 +62,7 @@ class UsageManager:
54
  self._usage_data = {}
55
  return
56
  try:
57
- async with aiofiles.open(self.file_path, 'r') as f:
58
  content = await f.read()
59
  self._usage_data = json.loads(content)
60
  except (json.JSONDecodeError, IOError, FileNotFoundError):
@@ -65,7 +73,7 @@ class UsageManager:
65
  if self._usage_data is None:
66
  return
67
  async with self._data_lock:
68
- async with aiofiles.open(self.file_path, 'w') as f:
69
  await f.write(json.dumps(self._usage_data, indent=2))
70
 
71
  async def _reset_daily_stats_if_needed(self):
@@ -79,24 +87,31 @@ class UsageManager:
79
 
80
  for key, data in self._usage_data.items():
81
  last_reset_str = data.get("last_daily_reset", "")
82
-
83
  if last_reset_str != today_str:
84
  last_reset_dt = None
85
  if last_reset_str:
86
  # Ensure the parsed datetime is timezone-aware (UTC)
87
- last_reset_dt = datetime.fromisoformat(last_reset_str).replace(tzinfo=timezone.utc)
 
 
88
 
89
  # Determine the reset threshold for today
90
- reset_threshold_today = datetime.combine(now_utc.date(), self.daily_reset_time_utc)
91
-
92
- if last_reset_dt is None or last_reset_dt < reset_threshold_today <= now_utc:
 
 
 
 
 
93
  lib_logger.info(f"Performing daily reset for key ...{key[-6:]}")
94
  needs_saving = True
95
-
96
  # Reset cooldowns
97
  data["model_cooldowns"] = {}
98
  data["key_cooldown_until"] = None
99
-
100
  # Reset consecutive failures
101
  if "failures" in data:
102
  data["failures"] = {}
@@ -106,12 +121,28 @@ class UsageManager:
106
  if daily_data:
107
  global_data = data.setdefault("global", {"models": {}})
108
  for model, stats in daily_data.get("models", {}).items():
109
- global_model_stats = global_data["models"].setdefault(model, {"success_count": 0, "prompt_tokens": 0, "completion_tokens": 0, "approx_cost": 0.0})
110
- global_model_stats["success_count"] += stats.get("success_count", 0)
111
- global_model_stats["prompt_tokens"] += stats.get("prompt_tokens", 0)
112
- global_model_stats["completion_tokens"] += stats.get("completion_tokens", 0)
113
- global_model_stats["approx_cost"] += stats.get("approx_cost", 0.0)
114
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  # Reset daily stats
116
  data["daily"] = {"date": today_str, "models": {}}
117
  data["last_daily_reset"] = today_str
@@ -126,10 +157,12 @@ class UsageManager:
126
  self.key_states[key] = {
127
  "lock": asyncio.Lock(),
128
  "condition": asyncio.Condition(),
129
- "models_in_use": set()
130
  }
131
 
132
- async def acquire_key(self, available_keys: List[str], model: str, deadline: float) -> str:
 
 
133
  """
134
  Acquires the best available key using a tiered, model-aware locking strategy,
135
  respecting a global deadline.
@@ -142,18 +175,24 @@ class UsageManager:
142
  while time.time() < deadline:
143
  tier1_keys, tier2_keys = [], []
144
  now = time.time()
145
-
146
  # First, filter the list of available keys to exclude any on cooldown.
147
  async with self._data_lock:
148
  for key in available_keys:
149
  key_data = self._usage_data.get(key, {})
150
-
151
- if (key_data.get("key_cooldown_until") or 0) > now or \
152
- (key_data.get("model_cooldowns", {}).get(model) or 0) > now:
 
153
  continue
154
 
155
  # Prioritize keys based on their current usage to ensure load balancing.
156
- usage_count = key_data.get("daily", {}).get("models", {}).get(model, {}).get("success_count", 0)
 
 
 
 
 
157
  key_state = self.key_states[key]
158
 
159
  # Tier 1: Completely idle keys (preferred).
@@ -172,7 +211,9 @@ class UsageManager:
172
  async with state["lock"]:
173
  if not state["models_in_use"]:
174
  state["models_in_use"].add(model)
175
- lib_logger.info(f"Acquired Tier 1 key ...{key[-6:]} for model {model}")
 
 
176
  return key
177
 
178
  # If no Tier 1 keys are available, try Tier 2.
@@ -181,37 +222,46 @@ class UsageManager:
181
  async with state["lock"]:
182
  if model not in state["models_in_use"]:
183
  state["models_in_use"].add(model)
184
- lib_logger.info(f"Acquired Tier 2 key ...{key[-6:]} for model {model}")
 
 
185
  return key
186
 
187
  # If all eligible keys are locked, wait for a key to be released.
188
- lib_logger.info("All eligible keys are currently locked for this model. Waiting...")
189
-
 
 
190
  all_potential_keys = tier1_keys + tier2_keys
191
  if not all_potential_keys:
192
- lib_logger.warning("No keys are eligible (all on cooldown). Waiting before re-evaluating.")
 
 
193
  await asyncio.sleep(1)
194
  continue
195
 
196
  # Wait on the condition of the key with the lowest current usage.
197
  best_wait_key = min(all_potential_keys, key=lambda x: x[1])[0]
198
  wait_condition = self.key_states[best_wait_key]["condition"]
199
-
200
  try:
201
  async with wait_condition:
202
  remaining_budget = deadline - time.time()
203
  if remaining_budget <= 0:
204
- break # Exit if the budget has already been exceeded.
205
  # Wait for a notification, but no longer than the remaining budget or 1 second.
206
- await asyncio.wait_for(wait_condition.wait(), timeout=min(1, remaining_budget))
 
 
207
  lib_logger.info("Notified that a key was released. Re-evaluating...")
208
  except asyncio.TimeoutError:
209
  # This is not an error, just a timeout for the wait. The main loop will re-evaluate.
210
  lib_logger.info("Wait timed out. Re-evaluating for any available key.")
211
-
212
- # If the loop exits, it means the deadline was exceeded.
213
- raise NoAvailableKeysError(f"Could not acquire a key for model {model} within the global time budget.")
214
 
 
 
 
 
215
 
216
  async def release_key(self, key: str, model: str):
217
  """Releases a key's lock for a specific model and notifies waiting tasks."""
@@ -224,13 +274,20 @@ class UsageManager:
224
  state["models_in_use"].remove(model)
225
  lib_logger.info(f"Released credential ...{key[-6:]} from model {model}")
226
  else:
227
- lib_logger.warning(f"Attempted to release credential ...{key[-6:]} for model {model}, but it was not in use.")
 
 
228
 
229
  # Notify all tasks waiting on this key's condition
230
  async with state["condition"]:
231
  state["condition"].notify_all()
232
 
233
- async def record_success(self, key: str, model: str, completion_response: Optional[litellm.ModelResponse] = None):
 
 
 
 
 
234
  """
235
  Records a successful API call, resetting failure counters.
236
  It safely handles cases where token usage data is not available.
@@ -238,33 +295,59 @@ class UsageManager:
238
  await self._lazy_init()
239
  async with self._data_lock:
240
  today_utc_str = datetime.now(timezone.utc).date().isoformat()
241
- key_data = self._usage_data.setdefault(key, {"daily": {"date": today_utc_str, "models": {}}, "global": {"models": {}}, "model_cooldowns": {}, "failures": {}})
242
-
 
 
 
 
 
 
 
 
243
  # If the key is new, ensure its reset date is initialized to prevent an immediate reset.
244
  if "last_daily_reset" not in key_data:
245
  key_data["last_daily_reset"] = today_utc_str
246
-
247
  # Always record a success and reset failures
248
  model_failures = key_data.setdefault("failures", {}).setdefault(model, {})
249
  model_failures["consecutive_failures"] = 0
250
  if model in key_data.get("model_cooldowns", {}):
251
  del key_data["model_cooldowns"][model]
252
 
253
- daily_model_data = key_data["daily"]["models"].setdefault(model, {"success_count": 0, "prompt_tokens": 0, "completion_tokens": 0, "approx_cost": 0.0})
 
 
 
 
 
 
 
 
254
  daily_model_data["success_count"] += 1
255
 
256
  # Safely attempt to record token and cost usage
257
- if completion_response and hasattr(completion_response, 'usage') and completion_response.usage:
 
 
 
 
258
  usage = completion_response.usage
259
  daily_model_data["prompt_tokens"] += usage.prompt_tokens
260
- daily_model_data["completion_tokens"] += getattr(usage, 'completion_tokens', 0) # Not present in embedding responses
261
- lib_logger.info(f"Recorded usage from response object for key ...{key[-6:]}")
 
 
 
 
262
  try:
263
- provider_name = model.split('/')[0]
264
  provider_plugin = PROVIDER_PLUGINS.get(provider_name)
265
 
266
- if provider_plugin and provider_plugin.skip_cost_calculation:
267
- lib_logger.debug(f"Skipping cost calculation for provider '{provider_name}' as per its configuration.")
 
 
268
  else:
269
  # Differentiate cost calculation based on response type
270
  if isinstance(completion_response, litellm.EmbeddingResponse):
@@ -272,56 +355,85 @@ class UsageManager:
272
  model_info = litellm.get_model_info(model)
273
  input_cost = model_info.get("input_cost_per_token")
274
  if input_cost:
275
- cost = completion_response.usage.prompt_tokens * input_cost
 
 
276
  else:
277
  cost = None
278
  else:
279
- cost = litellm.completion_cost(completion_response=completion_response, model=model)
280
-
 
 
281
  if cost is not None:
282
  daily_model_data["approx_cost"] += cost
283
  except Exception as e:
284
- lib_logger.warning(f"Could not calculate cost for model {model}: {e}")
285
- elif isinstance(completion_response, asyncio.Future) or hasattr(completion_response, '__aiter__'):
 
 
 
 
286
  # This is an unconsumed stream object. Do not log a warning, as usage will be recorded from the chunks.
287
  pass
288
  else:
289
- lib_logger.warning(f"No usage data found in completion response for model {model}. Recording success without token count.")
 
 
290
 
291
  key_data["last_used_ts"] = time.time()
292
-
293
  await self._save_usage()
294
 
295
- async def record_failure(self, key: str, model: str, classified_error: ClassifiedError):
 
 
296
  """Records a failure and applies cooldowns based on an escalating backoff strategy."""
297
  await self._lazy_init()
298
  async with self._data_lock:
299
  today_utc_str = datetime.now(timezone.utc).date().isoformat()
300
- key_data = self._usage_data.setdefault(key, {"daily": {"date": today_utc_str, "models": {}}, "global": {"models": {}}, "model_cooldowns": {}, "failures": {}})
301
-
 
 
 
 
 
 
 
 
302
  # Handle specific error types first
303
- if classified_error.error_type == 'rate_limit' and classified_error.retry_after:
 
 
 
304
  cooldown_seconds = classified_error.retry_after
305
- elif classified_error.error_type == 'authentication':
306
  # Apply a 5-minute key-level lockout for auth errors
307
  key_data["key_cooldown_until"] = time.time() + 300
308
- lib_logger.warning(f"Authentication error on key ...{key[-6:]}. Applying 5-minute key-level lockout.")
 
 
309
  await self._save_usage()
310
- return # No further backoff logic needed
311
  else:
312
  # General backoff logic for other errors
313
  failures_data = key_data.setdefault("failures", {})
314
- model_failures = failures_data.setdefault(model, {"consecutive_failures": 0})
 
 
315
  model_failures["consecutive_failures"] += 1
316
  count = model_failures["consecutive_failures"]
317
 
318
  backoff_tiers = {1: 10, 2: 30, 3: 60, 4: 120}
319
- cooldown_seconds = backoff_tiers.get(count, 7200) # Default to 2 hours
320
 
321
  # Apply the cooldown
322
  model_cooldowns = key_data.setdefault("model_cooldowns", {})
323
  model_cooldowns[model] = time.time() + cooldown_seconds
324
- lib_logger.warning(f"Failure recorded for key ...{key[-6:]} with model {model}. Applying {cooldown_seconds}s cooldown.")
 
 
325
 
326
  # Check for key-level lockout condition
327
  await self._check_key_lockout(key, key_data)
@@ -329,20 +441,22 @@ class UsageManager:
329
  key_data["last_failure"] = {
330
  "timestamp": time.time(),
331
  "model": model,
332
- "error": str(classified_error.original_exception)
333
  }
334
-
335
  await self._save_usage()
336
 
337
  async def _check_key_lockout(self, key: str, key_data: Dict):
338
  """Checks if a key should be locked out due to multiple model failures."""
339
  long_term_lockout_models = 0
340
  now = time.time()
341
-
342
  for model, cooldown_end in key_data.get("model_cooldowns", {}).items():
343
- if cooldown_end - now >= 7200: # Check for 2-hour lockouts
344
  long_term_lockout_models += 1
345
-
346
  if long_term_lockout_models >= 3:
347
- key_data["key_cooldown_until"] = now + 300 # 5-minute key lockout
348
- lib_logger.error(f"Key ...{key[-6:]} has {long_term_lockout_models} models in long-term lockout. Applying 5-minute key-level lockout.")
 
 
 
11
  from .error_handler import ClassifiedError, NoAvailableKeysError
12
  from .providers import PROVIDER_PLUGINS
13
 
14
+ lib_logger = logging.getLogger("rotator_library")
15
  lib_logger.propagate = False
16
  if not lib_logger.handlers:
17
  lib_logger.addHandler(logging.NullHandler())
18
 
19
+
20
  class UsageManager:
21
  """
22
  Manages usage statistics and cooldowns for API keys with asyncio-safe locking,
23
  asynchronous file I/O, and a lazy-loading mechanism for usage data.
24
  """
25
+
26
+ def __init__(
27
+ self,
28
+ file_path: str = "key_usage.json",
29
+ daily_reset_time_utc: Optional[str] = "03:00",
30
+ ):
31
  self.file_path = file_path
32
  self.key_states: Dict[str, Dict[str, Any]] = {}
33
+
34
  self._data_lock = asyncio.Lock()
35
  self._usage_data: Optional[Dict] = None
36
  self._initialized = asyncio.Event()
 
40
  self._claimed_on_timeout: Set[str] = set()
41
 
42
  if daily_reset_time_utc:
43
+ hour, minute = map(int, daily_reset_time_utc.split(":"))
44
+ self.daily_reset_time_utc = dt_time(
45
+ hour=hour, minute=minute, tzinfo=timezone.utc
46
+ )
47
  else:
48
  self.daily_reset_time_utc = None
49
 
 
62
  self._usage_data = {}
63
  return
64
  try:
65
+ async with aiofiles.open(self.file_path, "r") as f:
66
  content = await f.read()
67
  self._usage_data = json.loads(content)
68
  except (json.JSONDecodeError, IOError, FileNotFoundError):
 
73
  if self._usage_data is None:
74
  return
75
  async with self._data_lock:
76
+ async with aiofiles.open(self.file_path, "w") as f:
77
  await f.write(json.dumps(self._usage_data, indent=2))
78
 
79
  async def _reset_daily_stats_if_needed(self):
 
87
 
88
  for key, data in self._usage_data.items():
89
  last_reset_str = data.get("last_daily_reset", "")
90
+
91
  if last_reset_str != today_str:
92
  last_reset_dt = None
93
  if last_reset_str:
94
  # Ensure the parsed datetime is timezone-aware (UTC)
95
+ last_reset_dt = datetime.fromisoformat(last_reset_str).replace(
96
+ tzinfo=timezone.utc
97
+ )
98
 
99
  # Determine the reset threshold for today
100
+ reset_threshold_today = datetime.combine(
101
+ now_utc.date(), self.daily_reset_time_utc
102
+ )
103
+
104
+ if (
105
+ last_reset_dt is None
106
+ or last_reset_dt < reset_threshold_today <= now_utc
107
+ ):
108
  lib_logger.info(f"Performing daily reset for key ...{key[-6:]}")
109
  needs_saving = True
110
+
111
  # Reset cooldowns
112
  data["model_cooldowns"] = {}
113
  data["key_cooldown_until"] = None
114
+
115
  # Reset consecutive failures
116
  if "failures" in data:
117
  data["failures"] = {}
 
121
  if daily_data:
122
  global_data = data.setdefault("global", {"models": {}})
123
  for model, stats in daily_data.get("models", {}).items():
124
+ global_model_stats = global_data["models"].setdefault(
125
+ model,
126
+ {
127
+ "success_count": 0,
128
+ "prompt_tokens": 0,
129
+ "completion_tokens": 0,
130
+ "approx_cost": 0.0,
131
+ },
132
+ )
133
+ global_model_stats["success_count"] += stats.get(
134
+ "success_count", 0
135
+ )
136
+ global_model_stats["prompt_tokens"] += stats.get(
137
+ "prompt_tokens", 0
138
+ )
139
+ global_model_stats["completion_tokens"] += stats.get(
140
+ "completion_tokens", 0
141
+ )
142
+ global_model_stats["approx_cost"] += stats.get(
143
+ "approx_cost", 0.0
144
+ )
145
+
146
  # Reset daily stats
147
  data["daily"] = {"date": today_str, "models": {}}
148
  data["last_daily_reset"] = today_str
 
157
  self.key_states[key] = {
158
  "lock": asyncio.Lock(),
159
  "condition": asyncio.Condition(),
160
+ "models_in_use": set(),
161
  }
162
 
163
+ async def acquire_key(
164
+ self, available_keys: List[str], model: str, deadline: float
165
+ ) -> str:
166
  """
167
  Acquires the best available key using a tiered, model-aware locking strategy,
168
  respecting a global deadline.
 
175
  while time.time() < deadline:
176
  tier1_keys, tier2_keys = [], []
177
  now = time.time()
178
+
179
  # First, filter the list of available keys to exclude any on cooldown.
180
  async with self._data_lock:
181
  for key in available_keys:
182
  key_data = self._usage_data.get(key, {})
183
+
184
+ if (key_data.get("key_cooldown_until") or 0) > now or (
185
+ key_data.get("model_cooldowns", {}).get(model) or 0
186
+ ) > now:
187
  continue
188
 
189
  # Prioritize keys based on their current usage to ensure load balancing.
190
+ usage_count = (
191
+ key_data.get("daily", {})
192
+ .get("models", {})
193
+ .get(model, {})
194
+ .get("success_count", 0)
195
+ )
196
  key_state = self.key_states[key]
197
 
198
  # Tier 1: Completely idle keys (preferred).
 
211
  async with state["lock"]:
212
  if not state["models_in_use"]:
213
  state["models_in_use"].add(model)
214
+ lib_logger.info(
215
+ f"Acquired Tier 1 key ...{key[-6:]} for model {model}"
216
+ )
217
  return key
218
 
219
  # If no Tier 1 keys are available, try Tier 2.
 
222
  async with state["lock"]:
223
  if model not in state["models_in_use"]:
224
  state["models_in_use"].add(model)
225
+ lib_logger.info(
226
+ f"Acquired Tier 2 key ...{key[-6:]} for model {model}"
227
+ )
228
  return key
229
 
230
  # If all eligible keys are locked, wait for a key to be released.
231
+ lib_logger.info(
232
+ "All eligible keys are currently locked for this model. Waiting..."
233
+ )
234
+
235
  all_potential_keys = tier1_keys + tier2_keys
236
  if not all_potential_keys:
237
+ lib_logger.warning(
238
+ "No keys are eligible (all on cooldown). Waiting before re-evaluating."
239
+ )
240
  await asyncio.sleep(1)
241
  continue
242
 
243
  # Wait on the condition of the key with the lowest current usage.
244
  best_wait_key = min(all_potential_keys, key=lambda x: x[1])[0]
245
  wait_condition = self.key_states[best_wait_key]["condition"]
246
+
247
  try:
248
  async with wait_condition:
249
  remaining_budget = deadline - time.time()
250
  if remaining_budget <= 0:
251
+ break # Exit if the budget has already been exceeded.
252
  # Wait for a notification, but no longer than the remaining budget or 1 second.
253
+ await asyncio.wait_for(
254
+ wait_condition.wait(), timeout=min(1, remaining_budget)
255
+ )
256
  lib_logger.info("Notified that a key was released. Re-evaluating...")
257
  except asyncio.TimeoutError:
258
  # This is not an error, just a timeout for the wait. The main loop will re-evaluate.
259
  lib_logger.info("Wait timed out. Re-evaluating for any available key.")
 
 
 
260
 
261
+ # If the loop exits, it means the deadline was exceeded.
262
+ raise NoAvailableKeysError(
263
+ f"Could not acquire a key for model {model} within the global time budget."
264
+ )
265
 
266
  async def release_key(self, key: str, model: str):
267
  """Releases a key's lock for a specific model and notifies waiting tasks."""
 
274
  state["models_in_use"].remove(model)
275
  lib_logger.info(f"Released credential ...{key[-6:]} from model {model}")
276
  else:
277
+ lib_logger.warning(
278
+ f"Attempted to release credential ...{key[-6:]} for model {model}, but it was not in use."
279
+ )
280
 
281
  # Notify all tasks waiting on this key's condition
282
  async with state["condition"]:
283
  state["condition"].notify_all()
284
 
285
+ async def record_success(
286
+ self,
287
+ key: str,
288
+ model: str,
289
+ completion_response: Optional[litellm.ModelResponse] = None,
290
+ ):
291
  """
292
  Records a successful API call, resetting failure counters.
293
  It safely handles cases where token usage data is not available.
 
295
  await self._lazy_init()
296
  async with self._data_lock:
297
  today_utc_str = datetime.now(timezone.utc).date().isoformat()
298
+ key_data = self._usage_data.setdefault(
299
+ key,
300
+ {
301
+ "daily": {"date": today_utc_str, "models": {}},
302
+ "global": {"models": {}},
303
+ "model_cooldowns": {},
304
+ "failures": {},
305
+ },
306
+ )
307
+
308
  # If the key is new, ensure its reset date is initialized to prevent an immediate reset.
309
  if "last_daily_reset" not in key_data:
310
  key_data["last_daily_reset"] = today_utc_str
311
+
312
  # Always record a success and reset failures
313
  model_failures = key_data.setdefault("failures", {}).setdefault(model, {})
314
  model_failures["consecutive_failures"] = 0
315
  if model in key_data.get("model_cooldowns", {}):
316
  del key_data["model_cooldowns"][model]
317
 
318
+ daily_model_data = key_data["daily"]["models"].setdefault(
319
+ model,
320
+ {
321
+ "success_count": 0,
322
+ "prompt_tokens": 0,
323
+ "completion_tokens": 0,
324
+ "approx_cost": 0.0,
325
+ },
326
+ )
327
  daily_model_data["success_count"] += 1
328
 
329
  # Safely attempt to record token and cost usage
330
+ if (
331
+ completion_response
332
+ and hasattr(completion_response, "usage")
333
+ and completion_response.usage
334
+ ):
335
  usage = completion_response.usage
336
  daily_model_data["prompt_tokens"] += usage.prompt_tokens
337
+ daily_model_data["completion_tokens"] += getattr(
338
+ usage, "completion_tokens", 0
339
+ ) # Not present in embedding responses
340
+ lib_logger.info(
341
+ f"Recorded usage from response object for key ...{key[-6:]}"
342
+ )
343
  try:
344
+ provider_name = model.split("/")[0]
345
  provider_plugin = PROVIDER_PLUGINS.get(provider_name)
346
 
347
+ if provider_plugin and provider_plugin.skip_cost_calculation():
348
+ lib_logger.debug(
349
+ f"Skipping cost calculation for provider '{provider_name}' (custom provider)."
350
+ )
351
  else:
352
  # Differentiate cost calculation based on response type
353
  if isinstance(completion_response, litellm.EmbeddingResponse):
 
355
  model_info = litellm.get_model_info(model)
356
  input_cost = model_info.get("input_cost_per_token")
357
  if input_cost:
358
+ cost = (
359
+ completion_response.usage.prompt_tokens * input_cost
360
+ )
361
  else:
362
  cost = None
363
  else:
364
+ cost = litellm.completion_cost(
365
+ completion_response=completion_response, model=model
366
+ )
367
+
368
  if cost is not None:
369
  daily_model_data["approx_cost"] += cost
370
  except Exception as e:
371
+ lib_logger.warning(
372
+ f"Could not calculate cost for model {model}: {e}"
373
+ )
374
+ elif isinstance(completion_response, asyncio.Future) or hasattr(
375
+ completion_response, "__aiter__"
376
+ ):
377
  # This is an unconsumed stream object. Do not log a warning, as usage will be recorded from the chunks.
378
  pass
379
  else:
380
+ lib_logger.warning(
381
+ f"No usage data found in completion response for model {model}. Recording success without token count."
382
+ )
383
 
384
  key_data["last_used_ts"] = time.time()
385
+
386
  await self._save_usage()
387
 
388
    async def record_failure(
        self, key: str, model: str, classified_error: ClassifiedError
    ) -> None:
        """Records a failure and applies cooldowns based on an escalating backoff strategy.

        Args:
            key: The API key/credential that failed.
            model: The fully-qualified model name ("provider/model").
            classified_error: Pre-classified error carrying `error_type`,
                optional `retry_after`, and the original exception.
        """
        await self._lazy_init()
        async with self._data_lock:
            today_utc_str = datetime.now(timezone.utc).date().isoformat()
            # First failure on an unseen key creates its full usage skeleton.
            key_data = self._usage_data.setdefault(
                key,
                {
                    "daily": {"date": today_utc_str, "models": {}},
                    "global": {"models": {}},
                    "model_cooldowns": {},
                    "failures": {},
                },
            )

            # Handle specific error types first
            if (
                classified_error.error_type == "rate_limit"
                and classified_error.retry_after
            ):
                # Provider told us exactly how long to back off.
                cooldown_seconds = classified_error.retry_after
            elif classified_error.error_type == "authentication":
                # Apply a 5-minute key-level lockout for auth errors
                key_data["key_cooldown_until"] = time.time() + 300
                lib_logger.warning(
                    f"Authentication error on key ...{key[-6:]}. Applying 5-minute key-level lockout."
                )
                # NOTE(review): _save_usage() also acquires self._data_lock;
                # awaiting it while the lock is held would deadlock (asyncio.Lock
                # is not reentrant). Verify the lock scope in the original file.
                await self._save_usage()
                return  # No further backoff logic needed
            else:
                # General backoff logic for other errors
                failures_data = key_data.setdefault("failures", {})
                model_failures = failures_data.setdefault(
                    model, {"consecutive_failures": 0}
                )
                model_failures["consecutive_failures"] += 1
                count = model_failures["consecutive_failures"]

                # Escalating per-model backoff: 10s, 30s, 60s, 120s, then 2h.
                backoff_tiers = {1: 10, 2: 30, 3: 60, 4: 120}
                cooldown_seconds = backoff_tiers.get(count, 7200)  # Default to 2 hours

            # Apply the cooldown
            model_cooldowns = key_data.setdefault("model_cooldowns", {})
            model_cooldowns[model] = time.time() + cooldown_seconds
            lib_logger.warning(
                f"Failure recorded for key ...{key[-6:]} with model {model}. Applying {cooldown_seconds}s cooldown."
            )

            # Check for key-level lockout condition
            await self._check_key_lockout(key, key_data)

            # Keep a snapshot of the most recent failure for diagnostics.
            key_data["last_failure"] = {
                "timestamp": time.time(),
                "model": model,
                "error": str(classified_error.original_exception),
            }

        # NOTE(review): see note above — _save_usage() re-acquires
        # self._data_lock, so it must run outside the lock. Confirm
        # indentation against the original file.
        await self._save_usage()
448
 
449
  async def _check_key_lockout(self, key: str, key_data: Dict):
450
  """Checks if a key should be locked out due to multiple model failures."""
451
  long_term_lockout_models = 0
452
  now = time.time()
453
+
454
  for model, cooldown_end in key_data.get("model_cooldowns", {}).items():
455
+ if cooldown_end - now >= 7200: # Check for 2-hour lockouts
456
  long_term_lockout_models += 1
457
+
458
  if long_term_lockout_models >= 3:
459
+ key_data["key_cooldown_until"] = now + 300 # 5-minute key lockout
460
+ lib_logger.error(
461
+ f"Key ...{key[-6:]} has {long_term_lockout_models} models in long-term lockout. Applying 5-minute key-level lockout."
462
+ )