Mirrowel committed on
Commit
4dfb828
·
1 Parent(s): f03c448

feat(providers): ✨ implement credential tier initialization and persistence system

Browse files

This commit introduces a comprehensive credential tier management system across the library, enabling automatic tier detection, persistence, and intelligent credential prioritization at startup.

- Add `initialize_credentials()` method to `ProviderInterface` for startup credential loading
- Add `get_credential_tier_name()` method to expose human-readable tier names for logging
- Implement tier persistence in credential files via `_proxy_metadata` field
- Add lazy-loading fallback for tier data when not in memory cache
- Introduce `BackgroundRefresher._initialize_credentials()` to pre-load all provider tiers before refresh loop
- Pass `credential_tier_names` map through client to usage_manager for enhanced logging
- Update `UsageManager.acquire_key()` to display tier information in acquisition logs
- Make `ModelDefinitions` a singleton to prevent duplicate loading across providers
- Add comprehensive 3-line startup summary showing provider counts, credentials, and tier breakdown
- Implement tier-aware logging in Antigravity and GeminiCli providers with disk persistence
- Fix provider instance lookup for OAuth providers by handling `_oauth` suffix correctly

This ensures all credential priorities are known before any API calls, preventing unknown credentials from getting priority 999 and improving load balancing from the first request.

src/rotator_library/background_refresher.py CHANGED
@@ -8,28 +8,35 @@ from typing import TYPE_CHECKING, Optional
8
  if TYPE_CHECKING:
9
  from .client import RotatingClient
10
 
11
- lib_logger = logging.getLogger('rotator_library')
 
12
 
13
  class BackgroundRefresher:
14
  """
15
  A background task that periodically checks and refreshes OAuth tokens
16
  to ensure they remain valid.
17
  """
18
- def __init__(self, client: 'RotatingClient'):
 
19
  try:
20
  interval_str = os.getenv("OAUTH_REFRESH_INTERVAL", "600")
21
  self._interval = int(interval_str)
22
  except ValueError:
23
- lib_logger.warning(f"Invalid OAUTH_REFRESH_INTERVAL '{interval_str}'. Falling back to 600s.")
 
 
24
  self._interval = 600
25
  self._client = client
26
  self._task: Optional[asyncio.Task] = None
 
27
 
28
  def start(self):
29
  """Starts the background refresh task."""
30
  if self._task is None:
31
  self._task = asyncio.create_task(self._run())
32
- lib_logger.info(f"Background token refresher started. Check interval: {self._interval} seconds.")
 
 
33
  # [NEW] Log if custom interval is set
34
 
35
  async def stop(self):
@@ -42,23 +49,107 @@ class BackgroundRefresher:
42
  pass
43
  lib_logger.info("Background token refresher stopped.")
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  async def _run(self):
46
  """The main loop for the background task."""
 
 
 
47
  while True:
48
  try:
49
- #lib_logger.info("Running proactive token refresh check...")
50
 
51
  oauth_configs = self._client.get_oauth_credentials()
52
  for provider, paths in oauth_configs.items():
53
- provider_plugin = self._client._get_provider_instance(f"{provider}_oauth")
54
- if provider_plugin and hasattr(provider_plugin, 'proactively_refresh'):
 
 
55
  for path in paths:
56
  try:
57
  await provider_plugin.proactively_refresh(path)
58
  except Exception as e:
59
- lib_logger.error(f"Error during proactive refresh for '{path}': {e}")
 
 
60
  await asyncio.sleep(self._interval)
61
  except asyncio.CancelledError:
62
  break
63
  except Exception as e:
64
- lib_logger.error(f"Unexpected error in background refresher loop: {e}")
 
8
  if TYPE_CHECKING:
9
  from .client import RotatingClient
10
 
11
+ lib_logger = logging.getLogger("rotator_library")
12
+
13
 
14
  class BackgroundRefresher:
15
  """
16
  A background task that periodically checks and refreshes OAuth tokens
17
  to ensure they remain valid.
18
  """
19
+
20
+ def __init__(self, client: "RotatingClient"):
21
  try:
22
  interval_str = os.getenv("OAUTH_REFRESH_INTERVAL", "600")
23
  self._interval = int(interval_str)
24
  except ValueError:
25
+ lib_logger.warning(
26
+ f"Invalid OAUTH_REFRESH_INTERVAL '{interval_str}'. Falling back to 600s."
27
+ )
28
  self._interval = 600
29
  self._client = client
30
  self._task: Optional[asyncio.Task] = None
31
+ self._initialized = False
32
 
33
  def start(self):
34
  """Starts the background refresh task."""
35
  if self._task is None:
36
  self._task = asyncio.create_task(self._run())
37
+ lib_logger.info(
38
+ f"Background token refresher started. Check interval: {self._interval} seconds."
39
+ )
40
  # [NEW] Log if custom interval is set
41
 
42
  async def stop(self):
 
49
  pass
50
  lib_logger.info("Background token refresher stopped.")
51
 
52
+ async def _initialize_credentials(self):
53
+ """
54
+ Initialize all providers by loading credentials and persisted tier data.
55
+ Called once before the main refresh loop starts.
56
+ """
57
+ if self._initialized:
58
+ return
59
+
60
+ api_summary = {} # provider -> count
61
+ oauth_summary = {} # provider -> {"count": N, "tiers": {tier: count}}
62
+
63
+ all_credentials = self._client.all_credentials
64
+ oauth_providers = self._client.oauth_providers
65
+
66
+ for provider, credentials in all_credentials.items():
67
+ if not credentials:
68
+ continue
69
+
70
+ provider_plugin = self._client._get_provider_instance(provider)
71
+
72
+ # Call initialize_credentials if provider supports it
73
+ if provider_plugin and hasattr(provider_plugin, "initialize_credentials"):
74
+ try:
75
+ await provider_plugin.initialize_credentials(credentials)
76
+ except Exception as e:
77
+ lib_logger.error(
78
+ f"Error initializing credentials for provider '{provider}': {e}"
79
+ )
80
+
81
+ # Build summary based on provider type
82
+ if provider in oauth_providers:
83
+ tier_breakdown = {}
84
+ if provider_plugin and hasattr(
85
+ provider_plugin, "get_credential_tier_name"
86
+ ):
87
+ for cred in credentials:
88
+ tier = provider_plugin.get_credential_tier_name(cred)
89
+ if tier:
90
+ tier_breakdown[tier] = tier_breakdown.get(tier, 0) + 1
91
+ oauth_summary[provider] = {
92
+ "count": len(credentials),
93
+ "tiers": tier_breakdown,
94
+ }
95
+ else:
96
+ api_summary[provider] = len(credentials)
97
+
98
+ # Log 3-line summary
99
+ total_providers = len(api_summary) + len(oauth_summary)
100
+ total_credentials = sum(api_summary.values()) + sum(
101
+ d["count"] for d in oauth_summary.values()
102
+ )
103
+
104
+ if total_providers > 0:
105
+ lib_logger.info(
106
+ f"Providers initialized: {total_providers} providers, {total_credentials} credentials"
107
+ )
108
+
109
+ # API providers line
110
+ if api_summary:
111
+ api_parts = [f"{p}:{c}" for p, c in sorted(api_summary.items())]
112
+ lib_logger.info(f" API: {', '.join(api_parts)}")
113
+
114
+ # OAuth providers line with tier breakdown
115
+ if oauth_summary:
116
+ oauth_parts = []
117
+ for provider, data in sorted(oauth_summary.items()):
118
+ if data["tiers"]:
119
+ tier_str = ", ".join(
120
+ f"{t}:{c}" for t, c in sorted(data["tiers"].items())
121
+ )
122
+ oauth_parts.append(f"{provider}:{data['count']} ({tier_str})")
123
+ else:
124
+ oauth_parts.append(f"{provider}:{data['count']}")
125
+ lib_logger.info(f" OAuth: {', '.join(oauth_parts)}")
126
+
127
+ self._initialized = True
128
+
129
  async def _run(self):
130
  """The main loop for the background task."""
131
+ # Initialize credentials (load persisted tiers) before starting the refresh loop
132
+ await self._initialize_credentials()
133
+
134
  while True:
135
  try:
136
+ # lib_logger.info("Running proactive token refresh check...")
137
 
138
  oauth_configs = self._client.get_oauth_credentials()
139
  for provider, paths in oauth_configs.items():
140
+ provider_plugin = self._client._get_provider_instance(provider)
141
+ if provider_plugin and hasattr(
142
+ provider_plugin, "proactively_refresh"
143
+ ):
144
  for path in paths:
145
  try:
146
  await provider_plugin.proactively_refresh(path)
147
  except Exception as e:
148
+ lib_logger.error(
149
+ f"Error during proactive refresh for '{path}': {e}"
150
+ )
151
  await asyncio.sleep(self._interval)
152
  except asyncio.CancelledError:
153
  break
154
  except Exception as e:
155
+ lib_logger.error(f"Unexpected error in background refresher loop: {e}")
src/rotator_library/client.py CHANGED
@@ -447,12 +447,23 @@ class RotatingClient:
447
 
448
  Args:
449
  provider_name: The name of the provider to get an instance for.
 
 
 
450
 
451
  Returns:
452
  Provider instance if credentials exist, None otherwise.
453
  """
 
 
 
 
 
 
 
 
454
  # Only initialize providers for which we have credentials
455
- if provider_name not in self.all_credentials:
456
  lib_logger.debug(
457
  f"Skipping provider '{provider_name}' initialization: no credentials configured"
458
  )
@@ -824,13 +835,20 @@ class RotatingClient:
824
  f"Request will likely fail."
825
  )
826
 
827
- # Build priority map for usage_manager
 
828
  if provider_plugin and hasattr(provider_plugin, "get_credential_priority"):
829
  credential_priorities = {}
 
830
  for cred in credentials_for_provider:
831
  priority = provider_plugin.get_credential_priority(cred)
832
  if priority is not None:
833
  credential_priorities[cred] = priority
 
 
 
 
 
834
 
835
  if credential_priorities:
836
  lib_logger.debug(
@@ -883,6 +901,7 @@ class RotatingClient:
883
  deadline=deadline,
884
  max_concurrent=max_concurrent,
885
  credential_priorities=credential_priorities,
 
886
  )
887
  key_acquired = True
888
  tried_creds.add(current_cred)
@@ -1371,13 +1390,20 @@ class RotatingClient:
1371
  f"Request will likely fail."
1372
  )
1373
 
1374
- # Build priority map for usage_manager
 
1375
  if provider_plugin and hasattr(provider_plugin, "get_credential_priority"):
1376
  credential_priorities = {}
 
1377
  for cred in credentials_for_provider:
1378
  priority = provider_plugin.get_credential_priority(cred)
1379
  if priority is not None:
1380
  credential_priorities[cred] = priority
 
 
 
 
 
1381
 
1382
  if credential_priorities:
1383
  lib_logger.debug(
@@ -1433,6 +1459,7 @@ class RotatingClient:
1433
  deadline=deadline,
1434
  max_concurrent=max_concurrent,
1435
  credential_priorities=credential_priorities,
 
1436
  )
1437
  key_acquired = True
1438
  tried_creds.add(current_cred)
 
447
 
448
  Args:
449
  provider_name: The name of the provider to get an instance for.
450
+ For OAuth providers, this may include "_oauth" suffix
451
+ (e.g., "antigravity_oauth"), but credentials are stored
452
+ under the base name (e.g., "antigravity").
453
 
454
  Returns:
455
  Provider instance if credentials exist, None otherwise.
456
  """
457
+ # For OAuth providers, credentials are stored under base name (without _oauth suffix)
458
+ # e.g., "antigravity_oauth" plugin → credentials under "antigravity"
459
+ credential_key = provider_name
460
+ if provider_name.endswith("_oauth"):
461
+ base_name = provider_name[:-6] # Remove "_oauth"
462
+ if base_name in self.oauth_providers:
463
+ credential_key = base_name
464
+
465
  # Only initialize providers for which we have credentials
466
+ if credential_key not in self.all_credentials:
467
  lib_logger.debug(
468
  f"Skipping provider '{provider_name}' initialization: no credentials configured"
469
  )
 
835
  f"Request will likely fail."
836
  )
837
 
838
+ # Build priority map and tier names map for usage_manager
839
+ credential_tier_names = None
840
  if provider_plugin and hasattr(provider_plugin, "get_credential_priority"):
841
  credential_priorities = {}
842
+ credential_tier_names = {}
843
  for cred in credentials_for_provider:
844
  priority = provider_plugin.get_credential_priority(cred)
845
  if priority is not None:
846
  credential_priorities[cred] = priority
847
+ # Also get tier name for logging
848
+ if hasattr(provider_plugin, "get_credential_tier_name"):
849
+ tier_name = provider_plugin.get_credential_tier_name(cred)
850
+ if tier_name:
851
+ credential_tier_names[cred] = tier_name
852
 
853
  if credential_priorities:
854
  lib_logger.debug(
 
901
  deadline=deadline,
902
  max_concurrent=max_concurrent,
903
  credential_priorities=credential_priorities,
904
+ credential_tier_names=credential_tier_names,
905
  )
906
  key_acquired = True
907
  tried_creds.add(current_cred)
 
1390
  f"Request will likely fail."
1391
  )
1392
 
1393
+ # Build priority map and tier names map for usage_manager
1394
+ credential_tier_names = None
1395
  if provider_plugin and hasattr(provider_plugin, "get_credential_priority"):
1396
  credential_priorities = {}
1397
+ credential_tier_names = {}
1398
  for cred in credentials_for_provider:
1399
  priority = provider_plugin.get_credential_priority(cred)
1400
  if priority is not None:
1401
  credential_priorities[cred] = priority
1402
+ # Also get tier name for logging
1403
+ if hasattr(provider_plugin, "get_credential_tier_name"):
1404
+ tier_name = provider_plugin.get_credential_tier_name(cred)
1405
+ if tier_name:
1406
+ credential_tier_names[cred] = tier_name
1407
 
1408
  if credential_priorities:
1409
  lib_logger.debug(
 
1459
  deadline=deadline,
1460
  max_concurrent=max_concurrent,
1461
  credential_priorities=credential_priorities,
1462
+ credential_tier_names=credential_tier_names,
1463
  )
1464
  key_acquired = True
1465
  tried_creds.add(current_cred)
src/rotator_library/model_definitions.py CHANGED
@@ -24,10 +24,23 @@ class ModelDefinitions:
24
  - IFLOW_MODELS='{"glm-4.6": {}}' - dict format, uses "glm-4.6" as both name and ID
25
  - IFLOW_MODELS='{"custom-name": {"id": "actual-id"}}' - dict format with custom ID
26
  - IFLOW_MODELS='{"model": {"id": "id", "options": {"temperature": 0.7}}}' - with options
 
 
27
  """
28
 
 
 
 
 
 
 
 
 
29
  def __init__(self, config_path: Optional[str] = None):
30
- """Initialize model definitions loader."""
 
 
 
31
  self.config_path = config_path
32
  self.definitions = {}
33
  self._load_definitions()
@@ -49,7 +62,11 @@ class ModelDefinitions:
49
  # Handle array format: ["model-1", "model-2", "model-3"]
50
  elif isinstance(models_json, list):
51
  # Convert array to dict format with empty definitions
52
- models_dict = {model_name: {} for model_name in models_json if isinstance(model_name, str)}
 
 
 
 
53
  self.definitions[provider_name] = models_dict
54
  lib_logger.info(
55
  f"Loaded {len(models_dict)} models for provider: {provider_name} (array format)"
 
24
  - IFLOW_MODELS='{"glm-4.6": {}}' - dict format, uses "glm-4.6" as both name and ID
25
  - IFLOW_MODELS='{"custom-name": {"id": "actual-id"}}' - dict format with custom ID
26
  - IFLOW_MODELS='{"model": {"id": "id", "options": {"temperature": 0.7}}}' - with options
27
+
28
+ This class is a singleton - instantiated once and shared across all providers.
29
  """
30
 
31
+ _instance: Optional["ModelDefinitions"] = None
32
+ _initialized: bool = False
33
+
34
+ def __new__(cls, config_path: Optional[str] = None):
35
+ if cls._instance is None:
36
+ cls._instance = super().__new__(cls)
37
+ return cls._instance
38
+
39
  def __init__(self, config_path: Optional[str] = None):
40
+ """Initialize model definitions loader (only runs once due to singleton)."""
41
+ if ModelDefinitions._initialized:
42
+ return
43
+ ModelDefinitions._initialized = True
44
  self.config_path = config_path
45
  self.definitions = {}
46
  self._load_definitions()
 
62
  # Handle array format: ["model-1", "model-2", "model-3"]
63
  elif isinstance(models_json, list):
64
  # Convert array to dict format with empty definitions
65
+ models_dict = {
66
+ model_name: {}
67
+ for model_name in models_json
68
+ if isinstance(model_name, str)
69
+ }
70
  self.definitions[provider_name] = models_dict
71
  lib_logger.info(
72
  f"Loaded {len(models_dict)} models for provider: {provider_name} (array format)"
src/rotator_library/providers/antigravity_provider.py CHANGED
@@ -595,6 +595,11 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
595
  Priority level (1-10) or None if tier not yet discovered
596
  """
597
  tier = self.project_tier_cache.get(credential)
 
 
 
 
 
598
  if not tier:
599
  return None # Not yet discovered
600
 
@@ -609,6 +614,60 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
609
  # Legacy and unknown get even lower
610
  return 10
611
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
612
  def get_model_tier_requirement(self, model: str) -> Optional[int]:
613
  """
614
  Returns the minimum priority tier required for a model.
@@ -622,6 +681,72 @@ class AntigravityProvider(AntigravityAuthBase, ProviderInterface):
622
  """
623
  return None
624
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
625
  # =========================================================================
626
  # MODEL UTILITIES
627
  # =========================================================================
 
595
  Priority level (1-10) or None if tier not yet discovered
596
  """
597
  tier = self.project_tier_cache.get(credential)
598
+
599
+ # Lazy load from file if not in cache
600
+ if not tier:
601
+ tier = self._load_tier_from_file(credential)
602
+
603
  if not tier:
604
  return None # Not yet discovered
605
 
 
614
  # Legacy and unknown get even lower
615
  return 10
616
 
617
+ def _load_tier_from_file(self, credential_path: str) -> Optional[str]:
618
+ """
619
+ Load tier from credential file's _proxy_metadata and cache it.
620
+
621
+ This is used as a fallback when the tier isn't in the memory cache,
622
+ typically on first access before initialize_credentials() has run.
623
+
624
+ Args:
625
+ credential_path: Path to the credential file
626
+
627
+ Returns:
628
+ Tier string if found, None otherwise
629
+ """
630
+ # Skip env:// paths (environment-based credentials)
631
+ if self._parse_env_credential_path(credential_path) is not None:
632
+ return None
633
+
634
+ try:
635
+ with open(credential_path, "r") as f:
636
+ creds = json.load(f)
637
+
638
+ metadata = creds.get("_proxy_metadata", {})
639
+ tier = metadata.get("tier")
640
+ project_id = metadata.get("project_id")
641
+
642
+ if tier:
643
+ self.project_tier_cache[credential_path] = tier
644
+ lib_logger.debug(
645
+ f"Lazy-loaded tier '{tier}' for credential: {Path(credential_path).name}"
646
+ )
647
+
648
+ if project_id and credential_path not in self.project_id_cache:
649
+ self.project_id_cache[credential_path] = project_id
650
+
651
+ return tier
652
+ except (FileNotFoundError, json.JSONDecodeError, KeyError) as e:
653
+ lib_logger.debug(f"Could not lazy-load tier from {credential_path}: {e}")
654
+ return None
655
+
656
+ def get_credential_tier_name(self, credential: str) -> Optional[str]:
657
+ """
658
+ Returns the human-readable tier name for a credential.
659
+
660
+ Args:
661
+ credential: The credential path
662
+
663
+ Returns:
664
+ Tier name string (e.g., "free-tier") or None if unknown
665
+ """
666
+ tier = self.project_tier_cache.get(credential)
667
+ if not tier:
668
+ tier = self._load_tier_from_file(credential)
669
+ return tier
670
+
671
  def get_model_tier_requirement(self, model: str) -> Optional[int]:
672
  """
673
  Returns the minimum priority tier required for a model.
 
681
  """
682
  return None
683
 
684
+ async def initialize_credentials(self, credential_paths: List[str]) -> None:
685
+ """
686
+ Load persisted tier information from credential files at startup.
687
+
688
+ This ensures all credential priorities are known before any API calls,
689
+ preventing unknown credentials from getting priority 999.
690
+ """
691
+ await self._load_persisted_tiers(credential_paths)
692
+
693
+ async def _load_persisted_tiers(
694
+ self, credential_paths: List[str]
695
+ ) -> Dict[str, str]:
696
+ """
697
+ Load persisted tier information from credential files into memory cache.
698
+
699
+ Args:
700
+ credential_paths: List of credential file paths
701
+
702
+ Returns:
703
+ Dict mapping credential path to tier name for logging purposes
704
+ """
705
+ loaded = {}
706
+ for path in credential_paths:
707
+ # Skip env:// paths (environment-based credentials)
708
+ if self._parse_env_credential_path(path) is not None:
709
+ continue
710
+
711
+ # Skip if already in cache
712
+ if path in self.project_tier_cache:
713
+ continue
714
+
715
+ try:
716
+ with open(path, "r") as f:
717
+ creds = json.load(f)
718
+
719
+ metadata = creds.get("_proxy_metadata", {})
720
+ tier = metadata.get("tier")
721
+ project_id = metadata.get("project_id")
722
+
723
+ if tier:
724
+ self.project_tier_cache[path] = tier
725
+ loaded[path] = tier
726
+ lib_logger.debug(
727
+ f"Loaded persisted tier '{tier}' for credential: {Path(path).name}"
728
+ )
729
+
730
+ if project_id:
731
+ self.project_id_cache[path] = project_id
732
+
733
+ except (FileNotFoundError, json.JSONDecodeError, KeyError) as e:
734
+ lib_logger.debug(f"Could not load persisted tier from {path}: {e}")
735
+
736
+ if loaded:
737
+ # Log summary at debug level
738
+ tier_counts: Dict[str, int] = {}
739
+ for tier in loaded.values():
740
+ tier_counts[tier] = tier_counts.get(tier, 0) + 1
741
+ lib_logger.debug(
742
+ f"Antigravity: Loaded {len(loaded)} credential tiers from disk: "
743
+ + ", ".join(
744
+ f"{tier}={count}" for tier, count in sorted(tier_counts.items())
745
+ )
746
+ )
747
+
748
+ return loaded
749
+
750
  # =========================================================================
751
  # MODEL UTILITIES
752
  # =========================================================================
src/rotator_library/providers/gemini_cli_provider.py CHANGED
The diff for this file is too large to render. See raw diff
 
src/rotator_library/providers/provider_interface.py CHANGED
@@ -3,13 +3,15 @@ from typing import List, Dict, Any, Optional, AsyncGenerator, Union
3
  import httpx
4
  import litellm
5
 
 
6
  class ProviderInterface(ABC):
7
  """
8
  An interface for API provider-specific functionality, including model
9
  discovery and custom API call handling for non-standard providers.
10
  """
 
11
  skip_cost_calculation: bool = False
12
-
13
  @abstractmethod
14
  async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]:
15
  """
@@ -32,28 +34,38 @@ class ProviderInterface(ABC):
32
  """
33
  return False
34
 
35
- async def acompletion(self, client: httpx.AsyncClient, **kwargs) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]:
 
 
36
  """
37
  Handles the entire completion call for non-standard providers.
38
  """
39
- raise NotImplementedError(f"{self.__class__.__name__} does not implement custom acompletion.")
 
 
40
 
41
- async def aembedding(self, client: httpx.AsyncClient, **kwargs) -> litellm.EmbeddingResponse:
 
 
42
  """Handles the entire embedding call for non-standard providers."""
43
- raise NotImplementedError(f"{self.__class__.__name__} does not implement custom aembedding.")
44
-
45
- def convert_safety_settings(self, settings: Dict[str, str]) -> Optional[List[Dict[str, Any]]]:
 
 
 
 
46
  """
47
  Converts a generic safety settings dictionary to the provider-specific format.
48
-
49
  Args:
50
  settings: A dictionary with generic harm categories and thresholds.
51
-
52
  Returns:
53
  A list of provider-specific safety setting objects or None.
54
  """
55
  return None
56
-
57
  # [NEW] Add new methods for OAuth providers
58
  async def get_auth_header(self, credential_identifier: str) -> Dict[str, str]:
59
  """
@@ -67,23 +79,23 @@ class ProviderInterface(ABC):
67
  Proactively refreshes a token if it's nearing expiry.
68
  """
69
  pass
70
-
71
  # [NEW] Credential Prioritization System
72
  def get_credential_priority(self, credential: str) -> Optional[int]:
73
  """
74
  Returns the priority level for a credential.
75
  Lower numbers = higher priority (1 is highest).
76
  Returns None if provider doesn't use priorities.
77
-
78
  This allows providers to auto-detect credential tiers (e.g., paid vs free)
79
  and ensure higher-tier credentials are always tried first.
80
-
81
  Args:
82
  credential: The credential identifier (API key or path)
83
-
84
  Returns:
85
  Priority level (1-10) or None if no priority system
86
-
87
  Example:
88
  For Gemini CLI:
89
  - Paid tier credentials: priority 1 (highest)
@@ -91,24 +103,53 @@ class ProviderInterface(ABC):
91
  - Unknown tier: priority 10 (lowest)
92
  """
93
  return None
94
-
95
  def get_model_tier_requirement(self, model: str) -> Optional[int]:
96
  """
97
  Returns the minimum priority tier required for a model.
98
  If a model requires priority 1, only credentials with priority <= 1 can use it.
99
-
100
  This allows providers to restrict certain models to specific credential tiers.
101
  For example, Gemini 3 models require paid-tier credentials.
102
-
103
  Args:
104
  model: The model name (with or without provider prefix)
105
-
106
  Returns:
107
  Minimum required priority level or None if no restrictions
108
-
109
  Example:
110
  For Gemini CLI:
111
  - gemini-3-*: requires priority 1 (paid tier only)
112
  - gemini-2.5-*: no restriction (None)
113
  """
114
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import httpx
4
  import litellm
5
 
6
+
7
  class ProviderInterface(ABC):
8
  """
9
  An interface for API provider-specific functionality, including model
10
  discovery and custom API call handling for non-standard providers.
11
  """
12
+
13
  skip_cost_calculation: bool = False
14
+
15
  @abstractmethod
16
  async def get_models(self, api_key: str, client: httpx.AsyncClient) -> List[str]:
17
  """
 
34
  """
35
  return False
36
 
37
+ async def acompletion(
38
+ self, client: httpx.AsyncClient, **kwargs
39
+ ) -> Union[litellm.ModelResponse, AsyncGenerator[litellm.ModelResponse, None]]:
40
  """
41
  Handles the entire completion call for non-standard providers.
42
  """
43
+ raise NotImplementedError(
44
+ f"{self.__class__.__name__} does not implement custom acompletion."
45
+ )
46
 
47
+ async def aembedding(
48
+ self, client: httpx.AsyncClient, **kwargs
49
+ ) -> litellm.EmbeddingResponse:
50
  """Handles the entire embedding call for non-standard providers."""
51
+ raise NotImplementedError(
52
+ f"{self.__class__.__name__} does not implement custom aembedding."
53
+ )
54
+
55
+ def convert_safety_settings(
56
+ self, settings: Dict[str, str]
57
+ ) -> Optional[List[Dict[str, Any]]]:
58
  """
59
  Converts a generic safety settings dictionary to the provider-specific format.
60
+
61
  Args:
62
  settings: A dictionary with generic harm categories and thresholds.
63
+
64
  Returns:
65
  A list of provider-specific safety setting objects or None.
66
  """
67
  return None
68
+
69
  # [NEW] Add new methods for OAuth providers
70
  async def get_auth_header(self, credential_identifier: str) -> Dict[str, str]:
71
  """
 
79
  Proactively refreshes a token if it's nearing expiry.
80
  """
81
  pass
82
+
83
  # [NEW] Credential Prioritization System
84
  def get_credential_priority(self, credential: str) -> Optional[int]:
85
  """
86
  Returns the priority level for a credential.
87
  Lower numbers = higher priority (1 is highest).
88
  Returns None if provider doesn't use priorities.
89
+
90
  This allows providers to auto-detect credential tiers (e.g., paid vs free)
91
  and ensure higher-tier credentials are always tried first.
92
+
93
  Args:
94
  credential: The credential identifier (API key or path)
95
+
96
  Returns:
97
  Priority level (1-10) or None if no priority system
98
+
99
  Example:
100
  For Gemini CLI:
101
  - Paid tier credentials: priority 1 (highest)
 
103
  - Unknown tier: priority 10 (lowest)
104
  """
105
  return None
106
+
107
  def get_model_tier_requirement(self, model: str) -> Optional[int]:
108
  """
109
  Returns the minimum priority tier required for a model.
110
  If a model requires priority 1, only credentials with priority <= 1 can use it.
111
+
112
  This allows providers to restrict certain models to specific credential tiers.
113
  For example, Gemini 3 models require paid-tier credentials.
114
+
115
  Args:
116
  model: The model name (with or without provider prefix)
117
+
118
  Returns:
119
  Minimum required priority level or None if no restrictions
120
+
121
  Example:
122
  For Gemini CLI:
123
  - gemini-3-*: requires priority 1 (paid tier only)
124
  - gemini-2.5-*: no restriction (None)
125
  """
126
+ return None
127
+
128
+ async def initialize_credentials(self, credential_paths: List[str]) -> None:
129
+ """
130
+ Called at startup to initialize provider with all available credentials.
131
+
132
+ Providers can override this to load cached tier data, discover priorities,
133
+ or perform any other initialization needed before the first API request.
134
+
135
+ This is called once during startup by the BackgroundRefresher before
136
+ the main refresh loop begins.
137
+
138
+ Args:
139
+ credential_paths: List of credential file paths for this provider
140
+ """
141
+ pass
142
+
143
+ def get_credential_tier_name(self, credential: str) -> Optional[str]:
144
+ """
145
+ Returns the human-readable tier name for a credential.
146
+
147
+ This is used for logging purposes to show which plan tier a credential belongs to.
148
+
149
+ Args:
150
+ credential: The credential identifier (API key or path)
151
+
152
+ Returns:
153
+ Tier name string (e.g., "free-tier", "paid-tier") or None if unknown
154
+ """
155
+ return None
src/rotator_library/usage_manager.py CHANGED
@@ -22,24 +22,24 @@ class UsageManager:
22
  """
23
  Manages usage statistics and cooldowns for API keys with asyncio-safe locking,
24
  asynchronous file I/O, lazy-loading mechanism, and weighted random credential rotation.
25
-
26
  The credential rotation strategy can be configured via the `rotation_tolerance` parameter:
27
-
28
  - **tolerance = 0.0**: Deterministic least-used selection. The credential with
29
  the lowest usage count is always selected. This provides predictable, perfectly balanced
30
  load distribution but may be vulnerable to fingerprinting.
31
-
32
  - **tolerance = 2.0 - 4.0 (default, recommended)**: Balanced weighted randomness. Credentials are selected
33
  randomly with weights biased toward less-used ones. Credentials within 2 uses of the
34
  maximum can still be selected with reasonable probability. This provides security through
35
  unpredictability while maintaining good load balance.
36
-
37
  - **tolerance = 5.0+**: High randomness. Even heavily-used credentials have significant
38
  selection probability. Useful for stress testing or maximum unpredictability, but may
39
  result in less balanced load distribution.
40
-
41
  The weight formula is: `weight = (max_usage - credential_usage) + tolerance + 1`
42
-
43
  This ensures lower-usage credentials are preferred while tolerance controls how much
44
  randomness is introduced into the selection process.
45
  """
@@ -52,7 +52,7 @@ class UsageManager:
52
  ):
53
  """
54
  Initialize the UsageManager.
55
-
56
  Args:
57
  file_path: Path to the usage data JSON file
58
  daily_reset_time_utc: Time in UTC when daily stats should reset (HH:MM format)
@@ -139,7 +139,9 @@ class UsageManager:
139
  last_reset_dt is None
140
  or last_reset_dt < reset_threshold_today <= now_utc
141
  ):
142
- lib_logger.debug(f"Performing daily reset for key {mask_credential(key)}")
 
 
143
  needs_saving = True
144
 
145
  # Reset cooldowns
@@ -194,24 +196,20 @@ class UsageManager:
194
  "models_in_use": {}, # Dict[model_name, concurrent_count]
195
  }
196
 
197
- def _select_weighted_random(
198
- self,
199
- candidates: List[tuple],
200
- tolerance: float
201
- ) -> str:
202
  """
203
  Selects a credential using weighted random selection based on usage counts.
204
-
205
  Args:
206
  candidates: List of (credential_id, usage_count) tuples
207
  tolerance: Tolerance value for weight calculation
208
-
209
  Returns:
210
  Selected credential ID
211
-
212
  Formula:
213
  weight = (max_usage - credential_usage) + tolerance + 1
214
-
215
  This formula ensures:
216
  - Lower usage = higher weight = higher selection probability
217
  - Tolerance adds variability: higher tolerance means more randomness
@@ -219,63 +217,66 @@ class UsageManager:
219
  """
220
  if not candidates:
221
  raise ValueError("Cannot select from empty candidate list")
222
-
223
  if len(candidates) == 1:
224
  return candidates[0][0]
225
-
226
  # Extract usage counts
227
  usage_counts = [usage for _, usage in candidates]
228
  max_usage = max(usage_counts)
229
-
230
  # Calculate weights using the formula: (max - current) + tolerance + 1
231
  weights = []
232
  for credential, usage in candidates:
233
  weight = (max_usage - usage) + tolerance + 1
234
  weights.append(weight)
235
-
236
  # Log weight distribution for debugging
237
  if lib_logger.isEnabledFor(logging.DEBUG):
238
  total_weight = sum(weights)
239
  weight_info = ", ".join(
240
- f"{mask_credential(cred)}: w={w:.1f} ({w/total_weight*100:.1f}%)"
241
  for (cred, _), w in zip(candidates, weights)
242
  )
243
- #lib_logger.debug(f"Weighted selection candidates: {weight_info}")
244
-
245
  # Random selection with weights
246
  selected_credential = random.choices(
247
- [cred for cred, _ in candidates],
248
- weights=weights,
249
- k=1
250
  )[0]
251
-
252
  return selected_credential
253
 
254
  async def acquire_key(
255
- self, available_keys: List[str], model: str, deadline: float,
 
 
 
256
  max_concurrent: int = 1,
257
- credential_priorities: Optional[Dict[str, int]] = None
 
258
  ) -> str:
259
  """
260
  Acquires the best available key using a tiered, model-aware locking strategy,
261
  respecting a global deadline and credential priorities.
262
-
263
  Priority Logic:
264
  - Groups credentials by priority level (1=highest, 2=lower, etc.)
265
  - Always tries highest priority (lowest number) first
266
  - Within same priority, sorts by usage count (load balancing)
267
  - Only moves to next priority if all higher-priority keys exhausted/busy
268
-
269
  Args:
270
  available_keys: List of credential identifiers to choose from
271
  model: Model name being requested
272
  deadline: Timestamp after which to stop trying
273
  max_concurrent: Maximum concurrent requests allowed per credential
274
  credential_priorities: Optional dict mapping credentials to priority levels (1=highest)
275
-
 
276
  Returns:
277
  Selected credential identifier
278
-
279
  Raises:
280
  NoAvailableKeysError: If no key could be acquired within the deadline
281
  """
@@ -294,16 +295,16 @@ class UsageManager:
294
  async with self._data_lock:
295
  for key in available_keys:
296
  key_data = self._usage_data.get(key, {})
297
-
298
  # Skip keys on cooldown
299
  if (key_data.get("key_cooldown_until") or 0) > now or (
300
  key_data.get("model_cooldowns", {}).get(model) or 0
301
  ) > now:
302
  continue
303
-
304
  # Get priority for this key (default to 999 if not specified)
305
  priority = credential_priorities.get(key, 999)
306
-
307
  # Get usage count for load balancing within priority groups
308
  usage_count = (
309
  key_data.get("daily", {})
@@ -311,58 +312,75 @@ class UsageManager:
311
  .get(model, {})
312
  .get("success_count", 0)
313
  )
314
-
315
  # Group by priority
316
  if priority not in priority_groups:
317
  priority_groups[priority] = []
318
  priority_groups[priority].append((key, usage_count))
319
-
320
  # Try priority groups in order (1, 2, 3, ...)
321
  sorted_priorities = sorted(priority_groups.keys())
322
-
323
  for priority_level in sorted_priorities:
324
  keys_in_priority = priority_groups[priority_level]
325
-
326
  # Within each priority group, use existing tier1/tier2 logic
327
  tier1_keys, tier2_keys = [], []
328
  for key, usage_count in keys_in_priority:
329
  key_state = self.key_states[key]
330
-
331
  # Tier 1: Completely idle keys (preferred)
332
  if not key_state["models_in_use"]:
333
  tier1_keys.append((key, usage_count))
334
  # Tier 2: Keys that can accept more concurrent requests
335
  elif key_state["models_in_use"].get(model, 0) < max_concurrent:
336
  tier2_keys.append((key, usage_count))
337
-
338
  # Apply weighted random selection or deterministic sorting
339
- selection_method = "weighted-random" if self.rotation_tolerance > 0 else "least-used"
340
-
 
 
 
 
341
  if self.rotation_tolerance > 0:
342
  # Weighted random selection within each tier
343
  if tier1_keys:
344
- selected_key = self._select_weighted_random(tier1_keys, self.rotation_tolerance)
345
- tier1_keys = [(k, u) for k, u in tier1_keys if k == selected_key]
 
 
 
 
346
  if tier2_keys:
347
- selected_key = self._select_weighted_random(tier2_keys, self.rotation_tolerance)
348
- tier2_keys = [(k, u) for k, u in tier2_keys if k == selected_key]
 
 
 
 
349
  else:
350
  # Deterministic: sort by usage within each tier
351
  tier1_keys.sort(key=lambda x: x[1])
352
  tier2_keys.sort(key=lambda x: x[1])
353
-
354
  # Try to acquire from Tier 1 first
355
  for key, usage in tier1_keys:
356
  state = self.key_states[key]
357
  async with state["lock"]:
358
  if not state["models_in_use"]:
359
  state["models_in_use"][model] = 1
 
 
 
 
 
360
  lib_logger.info(
361
- f"Acquired Priority-{priority_level} Tier-1 key {mask_credential(key)} for model {model} "
362
- f"(selection: {selection_method}, usage: {usage})"
363
  )
364
  return key
365
-
366
  # Then try Tier 2
367
  for key, usage in tier2_keys:
368
  state = self.key_states[key]
@@ -370,35 +388,40 @@ class UsageManager:
370
  current_count = state["models_in_use"].get(model, 0)
371
  if current_count < max_concurrent:
372
  state["models_in_use"][model] = current_count + 1
 
 
 
 
 
373
  lib_logger.info(
374
- f"Acquired Priority-{priority_level} Tier-2 key {mask_credential(key)} for model {model} "
375
- f"(selection: {selection_method}, concurrent: {state['models_in_use'][model]}/{max_concurrent}, usage: {usage})"
376
  )
377
  return key
378
-
379
  # If we get here, all priority groups were exhausted but keys might become available
380
  # Collect all keys across all priorities for waiting
381
  all_potential_keys = []
382
  for keys_list in priority_groups.values():
383
  all_potential_keys.extend(keys_list)
384
-
385
  if not all_potential_keys:
386
  lib_logger.warning(
387
  "No keys are eligible (all on cooldown or filtered out). Waiting before re-evaluating."
388
  )
389
  await asyncio.sleep(1)
390
  continue
391
-
392
  # Wait for the highest priority key with lowest usage
393
  best_priority = min(priority_groups.keys())
394
  best_priority_keys = priority_groups[best_priority]
395
  best_wait_key = min(best_priority_keys, key=lambda x: x[1])[0]
396
  wait_condition = self.key_states[best_wait_key]["condition"]
397
-
398
  lib_logger.info(
399
  f"All Priority-{best_priority} keys are busy. Waiting for highest priority credential to become available..."
400
  )
401
-
402
  else:
403
  # Original logic when no priorities specified
404
  tier1_keys, tier2_keys = [], []
@@ -430,16 +453,26 @@ class UsageManager:
430
  tier2_keys.append((key, usage_count))
431
 
432
  # Apply weighted random selection or deterministic sorting
433
- selection_method = "weighted-random" if self.rotation_tolerance > 0 else "least-used"
434
-
 
 
435
  if self.rotation_tolerance > 0:
436
  # Weighted random selection within each tier
437
  if tier1_keys:
438
- selected_key = self._select_weighted_random(tier1_keys, self.rotation_tolerance)
439
- tier1_keys = [(k, u) for k, u in tier1_keys if k == selected_key]
 
 
 
 
440
  if tier2_keys:
441
- selected_key = self._select_weighted_random(tier2_keys, self.rotation_tolerance)
442
- tier2_keys = [(k, u) for k, u in tier2_keys if k == selected_key]
 
 
 
 
443
  else:
444
  # Deterministic: sort by usage within each tier
445
  tier1_keys.sort(key=lambda x: x[1])
@@ -451,9 +484,15 @@ class UsageManager:
451
  async with state["lock"]:
452
  if not state["models_in_use"]:
453
  state["models_in_use"][model] = 1
 
 
 
 
 
 
454
  lib_logger.info(
455
- f"Acquired Tier 1 key {mask_credential(key)} for model {model} "
456
- f"(selection: {selection_method}, usage: {usage})"
457
  )
458
  return key
459
 
@@ -464,9 +503,15 @@ class UsageManager:
464
  current_count = state["models_in_use"].get(model, 0)
465
  if current_count < max_concurrent:
466
  state["models_in_use"][model] = current_count + 1
 
 
 
 
 
 
467
  lib_logger.info(
468
- f"Acquired Tier 2 key {mask_credential(key)} for model {model} "
469
- f"(selection: {selection_method}, concurrent: {state['models_in_use'][model]}/{max_concurrent}, usage: {usage})"
470
  )
471
  return key
472
 
@@ -506,8 +551,6 @@ class UsageManager:
506
  f"Could not acquire a key for model {model} within the global time budget."
507
  )
508
 
509
-
510
-
511
  async def release_key(self, key: str, model: str):
512
  """Releases a key's lock for a specific model and notifies waiting tasks."""
513
  if key not in self.key_states:
@@ -640,8 +683,11 @@ class UsageManager:
640
  await self._save_usage()
641
 
642
  async def record_failure(
643
- self, key: str, model: str, classified_error: ClassifiedError,
644
- increment_consecutive_failures: bool = True
 
 
 
645
  ):
646
  """Records a failure and applies cooldowns based on an escalating backoff strategy.
647
 
@@ -705,7 +751,9 @@ class UsageManager:
705
  # If cooldown wasn't set by specific error type, use escalating backoff
706
  if cooldown_seconds is None:
707
  backoff_tiers = {1: 10, 2: 30, 3: 60, 4: 120}
708
- cooldown_seconds = backoff_tiers.get(count, 7200) # Default to 2 hours for "spent" keys
 
 
709
  lib_logger.warning(
710
  f"Failure #{count} for key {mask_credential(key)} with model {model}. "
711
  f"Error type: {classified_error.error_type}"
 
22
  """
23
  Manages usage statistics and cooldowns for API keys with asyncio-safe locking,
24
  asynchronous file I/O, lazy-loading mechanism, and weighted random credential rotation.
25
+
26
  The credential rotation strategy can be configured via the `rotation_tolerance` parameter:
27
+
28
  - **tolerance = 0.0**: Deterministic least-used selection. The credential with
29
  the lowest usage count is always selected. This provides predictable, perfectly balanced
30
  load distribution but may be vulnerable to fingerprinting.
31
+
32
  - **tolerance = 2.0 - 4.0 (default, recommended)**: Balanced weighted randomness. Credentials are selected
33
  randomly with weights biased toward less-used ones. Credentials within 2 uses of the
34
  maximum can still be selected with reasonable probability. This provides security through
35
  unpredictability while maintaining good load balance.
36
+
37
  - **tolerance = 5.0+**: High randomness. Even heavily-used credentials have significant
38
  selection probability. Useful for stress testing or maximum unpredictability, but may
39
  result in less balanced load distribution.
40
+
41
  The weight formula is: `weight = (max_usage - credential_usage) + tolerance + 1`
42
+
43
  This ensures lower-usage credentials are preferred while tolerance controls how much
44
  randomness is introduced into the selection process.
45
  """
 
52
  ):
53
  """
54
  Initialize the UsageManager.
55
+
56
  Args:
57
  file_path: Path to the usage data JSON file
58
  daily_reset_time_utc: Time in UTC when daily stats should reset (HH:MM format)
 
139
  last_reset_dt is None
140
  or last_reset_dt < reset_threshold_today <= now_utc
141
  ):
142
+ lib_logger.debug(
143
+ f"Performing daily reset for key {mask_credential(key)}"
144
+ )
145
  needs_saving = True
146
 
147
  # Reset cooldowns
 
196
  "models_in_use": {}, # Dict[model_name, concurrent_count]
197
  }
198
 
199
+ def _select_weighted_random(self, candidates: List[tuple], tolerance: float) -> str:
 
 
 
 
200
  """
201
  Selects a credential using weighted random selection based on usage counts.
202
+
203
  Args:
204
  candidates: List of (credential_id, usage_count) tuples
205
  tolerance: Tolerance value for weight calculation
206
+
207
  Returns:
208
  Selected credential ID
209
+
210
  Formula:
211
  weight = (max_usage - credential_usage) + tolerance + 1
212
+
213
  This formula ensures:
214
  - Lower usage = higher weight = higher selection probability
215
  - Tolerance adds variability: higher tolerance means more randomness
 
217
  """
218
  if not candidates:
219
  raise ValueError("Cannot select from empty candidate list")
220
+
221
  if len(candidates) == 1:
222
  return candidates[0][0]
223
+
224
  # Extract usage counts
225
  usage_counts = [usage for _, usage in candidates]
226
  max_usage = max(usage_counts)
227
+
228
  # Calculate weights using the formula: (max - current) + tolerance + 1
229
  weights = []
230
  for credential, usage in candidates:
231
  weight = (max_usage - usage) + tolerance + 1
232
  weights.append(weight)
233
+
234
  # Log weight distribution for debugging
235
  if lib_logger.isEnabledFor(logging.DEBUG):
236
  total_weight = sum(weights)
237
  weight_info = ", ".join(
238
+ f"{mask_credential(cred)}: w={w:.1f} ({w / total_weight * 100:.1f}%)"
239
  for (cred, _), w in zip(candidates, weights)
240
  )
241
+ # lib_logger.debug(f"Weighted selection candidates: {weight_info}")
242
+
243
  # Random selection with weights
244
  selected_credential = random.choices(
245
+ [cred for cred, _ in candidates], weights=weights, k=1
 
 
246
  )[0]
247
+
248
  return selected_credential
249
 
250
  async def acquire_key(
251
+ self,
252
+ available_keys: List[str],
253
+ model: str,
254
+ deadline: float,
255
  max_concurrent: int = 1,
256
+ credential_priorities: Optional[Dict[str, int]] = None,
257
+ credential_tier_names: Optional[Dict[str, str]] = None,
258
  ) -> str:
259
  """
260
  Acquires the best available key using a tiered, model-aware locking strategy,
261
  respecting a global deadline and credential priorities.
262
+
263
  Priority Logic:
264
  - Groups credentials by priority level (1=highest, 2=lower, etc.)
265
  - Always tries highest priority (lowest number) first
266
  - Within same priority, sorts by usage count (load balancing)
267
  - Only moves to next priority if all higher-priority keys exhausted/busy
268
+
269
  Args:
270
  available_keys: List of credential identifiers to choose from
271
  model: Model name being requested
272
  deadline: Timestamp after which to stop trying
273
  max_concurrent: Maximum concurrent requests allowed per credential
274
  credential_priorities: Optional dict mapping credentials to priority levels (1=highest)
275
+ credential_tier_names: Optional dict mapping credentials to tier names (for logging)
276
+
277
  Returns:
278
  Selected credential identifier
279
+
280
  Raises:
281
  NoAvailableKeysError: If no key could be acquired within the deadline
282
  """
 
295
  async with self._data_lock:
296
  for key in available_keys:
297
  key_data = self._usage_data.get(key, {})
298
+
299
  # Skip keys on cooldown
300
  if (key_data.get("key_cooldown_until") or 0) > now or (
301
  key_data.get("model_cooldowns", {}).get(model) or 0
302
  ) > now:
303
  continue
304
+
305
  # Get priority for this key (default to 999 if not specified)
306
  priority = credential_priorities.get(key, 999)
307
+
308
  # Get usage count for load balancing within priority groups
309
  usage_count = (
310
  key_data.get("daily", {})
 
312
  .get(model, {})
313
  .get("success_count", 0)
314
  )
315
+
316
  # Group by priority
317
  if priority not in priority_groups:
318
  priority_groups[priority] = []
319
  priority_groups[priority].append((key, usage_count))
320
+
321
  # Try priority groups in order (1, 2, 3, ...)
322
  sorted_priorities = sorted(priority_groups.keys())
323
+
324
  for priority_level in sorted_priorities:
325
  keys_in_priority = priority_groups[priority_level]
326
+
327
  # Within each priority group, use existing tier1/tier2 logic
328
  tier1_keys, tier2_keys = [], []
329
  for key, usage_count in keys_in_priority:
330
  key_state = self.key_states[key]
331
+
332
  # Tier 1: Completely idle keys (preferred)
333
  if not key_state["models_in_use"]:
334
  tier1_keys.append((key, usage_count))
335
  # Tier 2: Keys that can accept more concurrent requests
336
  elif key_state["models_in_use"].get(model, 0) < max_concurrent:
337
  tier2_keys.append((key, usage_count))
338
+
339
  # Apply weighted random selection or deterministic sorting
340
+ selection_method = (
341
+ "weighted-random"
342
+ if self.rotation_tolerance > 0
343
+ else "least-used"
344
+ )
345
+
346
  if self.rotation_tolerance > 0:
347
  # Weighted random selection within each tier
348
  if tier1_keys:
349
+ selected_key = self._select_weighted_random(
350
+ tier1_keys, self.rotation_tolerance
351
+ )
352
+ tier1_keys = [
353
+ (k, u) for k, u in tier1_keys if k == selected_key
354
+ ]
355
  if tier2_keys:
356
+ selected_key = self._select_weighted_random(
357
+ tier2_keys, self.rotation_tolerance
358
+ )
359
+ tier2_keys = [
360
+ (k, u) for k, u in tier2_keys if k == selected_key
361
+ ]
362
  else:
363
  # Deterministic: sort by usage within each tier
364
  tier1_keys.sort(key=lambda x: x[1])
365
  tier2_keys.sort(key=lambda x: x[1])
366
+
367
  # Try to acquire from Tier 1 first
368
  for key, usage in tier1_keys:
369
  state = self.key_states[key]
370
  async with state["lock"]:
371
  if not state["models_in_use"]:
372
  state["models_in_use"][model] = 1
373
+ tier_name = (
374
+ credential_tier_names.get(key, "unknown")
375
+ if credential_tier_names
376
+ else "unknown"
377
+ )
378
  lib_logger.info(
379
+ f"Acquired key {mask_credential(key)} for model {model} "
380
+ f"(tier: {tier_name}, priority: {priority_level}, selection: {selection_method}, usage: {usage})"
381
  )
382
  return key
383
+
384
  # Then try Tier 2
385
  for key, usage in tier2_keys:
386
  state = self.key_states[key]
 
388
  current_count = state["models_in_use"].get(model, 0)
389
  if current_count < max_concurrent:
390
  state["models_in_use"][model] = current_count + 1
391
+ tier_name = (
392
+ credential_tier_names.get(key, "unknown")
393
+ if credential_tier_names
394
+ else "unknown"
395
+ )
396
  lib_logger.info(
397
+ f"Acquired key {mask_credential(key)} for model {model} "
398
+ f"(tier: {tier_name}, priority: {priority_level}, selection: {selection_method}, concurrent: {state['models_in_use'][model]}/{max_concurrent}, usage: {usage})"
399
  )
400
  return key
401
+
402
  # If we get here, all priority groups were exhausted but keys might become available
403
  # Collect all keys across all priorities for waiting
404
  all_potential_keys = []
405
  for keys_list in priority_groups.values():
406
  all_potential_keys.extend(keys_list)
407
+
408
  if not all_potential_keys:
409
  lib_logger.warning(
410
  "No keys are eligible (all on cooldown or filtered out). Waiting before re-evaluating."
411
  )
412
  await asyncio.sleep(1)
413
  continue
414
+
415
  # Wait for the highest priority key with lowest usage
416
  best_priority = min(priority_groups.keys())
417
  best_priority_keys = priority_groups[best_priority]
418
  best_wait_key = min(best_priority_keys, key=lambda x: x[1])[0]
419
  wait_condition = self.key_states[best_wait_key]["condition"]
420
+
421
  lib_logger.info(
422
  f"All Priority-{best_priority} keys are busy. Waiting for highest priority credential to become available..."
423
  )
424
+
425
  else:
426
  # Original logic when no priorities specified
427
  tier1_keys, tier2_keys = [], []
 
453
  tier2_keys.append((key, usage_count))
454
 
455
  # Apply weighted random selection or deterministic sorting
456
+ selection_method = (
457
+ "weighted-random" if self.rotation_tolerance > 0 else "least-used"
458
+ )
459
+
460
  if self.rotation_tolerance > 0:
461
  # Weighted random selection within each tier
462
  if tier1_keys:
463
+ selected_key = self._select_weighted_random(
464
+ tier1_keys, self.rotation_tolerance
465
+ )
466
+ tier1_keys = [
467
+ (k, u) for k, u in tier1_keys if k == selected_key
468
+ ]
469
  if tier2_keys:
470
+ selected_key = self._select_weighted_random(
471
+ tier2_keys, self.rotation_tolerance
472
+ )
473
+ tier2_keys = [
474
+ (k, u) for k, u in tier2_keys if k == selected_key
475
+ ]
476
  else:
477
  # Deterministic: sort by usage within each tier
478
  tier1_keys.sort(key=lambda x: x[1])
 
484
  async with state["lock"]:
485
  if not state["models_in_use"]:
486
  state["models_in_use"][model] = 1
487
+ tier_name = (
488
+ credential_tier_names.get(key)
489
+ if credential_tier_names
490
+ else None
491
+ )
492
+ tier_info = f"tier: {tier_name}, " if tier_name else ""
493
  lib_logger.info(
494
+ f"Acquired key {mask_credential(key)} for model {model} "
495
+ f"({tier_info}selection: {selection_method}, usage: {usage})"
496
  )
497
  return key
498
 
 
503
  current_count = state["models_in_use"].get(model, 0)
504
  if current_count < max_concurrent:
505
  state["models_in_use"][model] = current_count + 1
506
+ tier_name = (
507
+ credential_tier_names.get(key)
508
+ if credential_tier_names
509
+ else None
510
+ )
511
+ tier_info = f"tier: {tier_name}, " if tier_name else ""
512
  lib_logger.info(
513
+ f"Acquired key {mask_credential(key)} for model {model} "
514
+ f"({tier_info}selection: {selection_method}, concurrent: {state['models_in_use'][model]}/{max_concurrent}, usage: {usage})"
515
  )
516
  return key
517
 
 
551
  f"Could not acquire a key for model {model} within the global time budget."
552
  )
553
 
 
 
554
  async def release_key(self, key: str, model: str):
555
  """Releases a key's lock for a specific model and notifies waiting tasks."""
556
  if key not in self.key_states:
 
683
  await self._save_usage()
684
 
685
  async def record_failure(
686
+ self,
687
+ key: str,
688
+ model: str,
689
+ classified_error: ClassifiedError,
690
+ increment_consecutive_failures: bool = True,
691
  ):
692
  """Records a failure and applies cooldowns based on an escalating backoff strategy.
693
 
 
751
  # If cooldown wasn't set by specific error type, use escalating backoff
752
  if cooldown_seconds is None:
753
  backoff_tiers = {1: 10, 2: 30, 3: 60, 4: 120}
754
+ cooldown_seconds = backoff_tiers.get(
755
+ count, 7200
756
+ ) # Default to 2 hours for "spent" keys
757
  lib_logger.warning(
758
  f"Failure #{count} for key {mask_credential(key)} with model {model}. "
759
  f"Error type: {classified_error.error_type}"