import json
import os
import time
import logging
import asyncio
import random
from datetime import date, datetime, timezone, time as dt_time
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import aiofiles
import litellm

from .error_handler import ClassifiedError, NoAvailableKeysError, mask_credential
from .providers import PROVIDER_PLUGINS
from .utils.resilient_io import ResilientStateWriter
from .utils.paths import get_data_file

lib_logger = logging.getLogger("rotator_library")
lib_logger.propagate = False
if not lib_logger.handlers:
    lib_logger.addHandler(logging.NullHandler())


class UsageManager:
    """
    Manages usage statistics and cooldowns for API keys with asyncio-safe
    locking, asynchronous file I/O, a lazy-loading mechanism, and weighted
    random credential rotation.

    The credential rotation strategy can be configured via the
    ``rotation_tolerance`` parameter:

    - **tolerance = 0.0**: Deterministic least-used selection. The credential
      with the lowest usage count is always selected. This provides
      predictable, perfectly balanced load distribution but may be vulnerable
      to fingerprinting.
    - **tolerance = 2.0 - 4.0 (recommended)**: Balanced weighted randomness.
      Credentials are selected randomly with weights biased toward less-used
      ones. Credentials within 2 uses of the maximum can still be selected
      with reasonable probability. This provides security through
      unpredictability while maintaining good load balance.
    - **tolerance = 5.0+**: High randomness. Even heavily-used credentials
      have significant selection probability. Useful for stress testing or
      maximum unpredictability, but may result in less balanced load
      distribution.

    The weight formula is::

        weight = (max_usage - credential_usage) + tolerance + 1

    This ensures lower-usage credentials are preferred while tolerance
    controls how much randomness is introduced into the selection process.

    Additionally, providers can specify a rotation mode:

    - "balanced" (default): Rotate credentials to distribute load evenly
    - "sequential": Use one credential until exhausted (preserves caching)
    """

    def __init__(
        self,
        file_path: Optional[Union[str, Path]] = None,
        daily_reset_time_utc: Optional[str] = "03:00",
        rotation_tolerance: float = 0.0,
        provider_rotation_modes: Optional[Dict[str, str]] = None,
        provider_plugins: Optional[Dict[str, Any]] = None,
        priority_multipliers: Optional[Dict[str, Dict[int, int]]] = None,
        priority_multipliers_by_mode: Optional[
            Dict[str, Dict[str, Dict[int, int]]]
        ] = None,
        sequential_fallback_multipliers: Optional[Dict[str, int]] = None,
    ):
        """
        Initialize the UsageManager.

        Args:
            file_path: Path to the usage data JSON file. If None, uses
                get_data_file("key_usage.json"). Can be an absolute Path, a
                relative Path, or a string.
            daily_reset_time_utc: Time in UTC when daily stats should reset
                (HH:MM format).
            rotation_tolerance: Tolerance for weighted random credential
                rotation. 0.0 means the least-used credential is always
                selected; 2.0-4.0 gives balanced randomness; 5.0+ gives highly
                unpredictable selection patterns.
            provider_rotation_modes: Dict mapping provider names to rotation
                modes ("balanced" distributes load evenly, "sequential" uses
                one credential until exhausted to preserve caching).
            provider_plugins: Dict mapping provider names to provider plugin
                classes/instances. Used for per-provider usage reset
                configuration (window durations, field names).
            priority_multipliers: Dict mapping provider -> priority ->
                multiplier. Universal multipliers that apply regardless of
                rotation mode. Example: {"antigravity": {1: 5, 2: 3}}
            priority_multipliers_by_mode: Dict mapping provider -> mode ->
                priority -> multiplier. Mode-specific overrides.
                Example: {"antigravity": {"balanced": {3: 1}}}
            sequential_fallback_multipliers: Dict mapping provider -> fallback
                multiplier. Used in sequential mode when the priority is not
                in priority_multipliers. Example: {"antigravity": 2}
        """
        # Resolve file_path - use default if not provided
        if file_path is None:
            self.file_path = str(get_data_file("key_usage.json"))
        elif isinstance(file_path, Path):
            self.file_path = str(file_path)
        else:
            # String path - could be relative or absolute
            self.file_path = file_path

        self.rotation_tolerance = rotation_tolerance
        self.provider_rotation_modes = provider_rotation_modes or {}
        self.provider_plugins = provider_plugins or PROVIDER_PLUGINS
        self.priority_multipliers = priority_multipliers or {}
        self.priority_multipliers_by_mode = priority_multipliers_by_mode or {}
        self.sequential_fallback_multipliers = sequential_fallback_multipliers or {}
        self._provider_instances: Dict[str, Any] = {}  # Cache for provider instances
        self.key_states: Dict[str, Dict[str, Any]] = {}
        self._data_lock = asyncio.Lock()
        self._usage_data: Optional[Dict] = None
        self._initialized = asyncio.Event()
        self._init_lock = asyncio.Lock()
        self._timeout_lock = asyncio.Lock()
        self._claimed_on_timeout: Set[str] = set()

        # Resilient writer for usage data persistence.
        # BUG FIX: previously the raw ``file_path`` argument was passed here,
        # which could be None (when defaulted) or a Path object; the writer
        # must receive the resolved string path computed above.
        self._state_writer = ResilientStateWriter(self.file_path, lib_logger)

        if daily_reset_time_utc:
            hour, minute = map(int, daily_reset_time_utc.split(":"))
            self.daily_reset_time_utc = dt_time(
                hour=hour, minute=minute, tzinfo=timezone.utc
            )
        else:
            self.daily_reset_time_utc = None
Example: {"antigravity": {"balanced": {3: 1}}} sequential_fallback_multipliers: Dict mapping provider -> fallback multiplier. Used in sequential mode when priority not in priority_multipliers. Example: {"antigravity": 2} """ # Resolve file_path - use default if not provided if file_path is None: self.file_path = str(get_data_file("key_usage.json")) elif isinstance(file_path, Path): self.file_path = str(file_path) else: # String path - could be relative or absolute self.file_path = file_path self.rotation_tolerance = rotation_tolerance self.provider_rotation_modes = provider_rotation_modes or {} self.provider_plugins = provider_plugins or PROVIDER_PLUGINS self.priority_multipliers = priority_multipliers or {} self.priority_multipliers_by_mode = priority_multipliers_by_mode or {} self.sequential_fallback_multipliers = sequential_fallback_multipliers or {} self._provider_instances: Dict[str, Any] = {} # Cache for provider instances self.key_states: Dict[str, Dict[str, Any]] = {} self._data_lock = asyncio.Lock() self._usage_data: Optional[Dict] = None self._initialized = asyncio.Event() self._init_lock = asyncio.Lock() self._timeout_lock = asyncio.Lock() self._claimed_on_timeout: Set[str] = set() # Resilient writer for usage data persistence self._state_writer = ResilientStateWriter(file_path, lib_logger) if daily_reset_time_utc: hour, minute = map(int, daily_reset_time_utc.split(":")) self.daily_reset_time_utc = dt_time( hour=hour, minute=minute, tzinfo=timezone.utc ) else: self.daily_reset_time_utc = None def _get_rotation_mode(self, provider: str) -> str: """ Get the rotation mode for a provider. Args: provider: Provider name (e.g., "antigravity", "gemini_cli") Returns: "balanced" or "sequential" """ return self.provider_rotation_modes.get(provider, "balanced") def _get_priority_multiplier( self, provider: str, priority: int, rotation_mode: str ) -> int: """ Get the concurrency multiplier for a provider/priority/mode combination. Lookup order: 1. 
Mode-specific tier override: priority_multipliers_by_mode[provider][mode][priority] 2. Universal tier multiplier: priority_multipliers[provider][priority] 3. Sequential fallback (if mode is sequential): sequential_fallback_multipliers[provider] 4. Global default: 1 (no multiplier effect) Args: provider: Provider name (e.g., "antigravity") priority: Priority level (1 = highest priority) rotation_mode: Current rotation mode ("sequential" or "balanced") Returns: Multiplier value """ provider_lower = provider.lower() # 1. Check mode-specific override if provider_lower in self.priority_multipliers_by_mode: mode_multipliers = self.priority_multipliers_by_mode[provider_lower] if rotation_mode in mode_multipliers: if priority in mode_multipliers[rotation_mode]: return mode_multipliers[rotation_mode][priority] # 2. Check universal tier multiplier if provider_lower in self.priority_multipliers: if priority in self.priority_multipliers[provider_lower]: return self.priority_multipliers[provider_lower][priority] # 3. Sequential fallback (only for sequential mode) if rotation_mode == "sequential": if provider_lower in self.sequential_fallback_multipliers: return self.sequential_fallback_multipliers[provider_lower] # 4. Global default return 1 def _get_provider_from_credential(self, credential: str) -> Optional[str]: """ Extract provider name from credential path or identifier. 
Supports multiple credential formats: - OAuth: "oauth_creds/antigravity_oauth_15.json" -> "antigravity" - OAuth: "C:\\...\\oauth_creds\\gemini_cli_oauth_1.json" -> "gemini_cli" - OAuth filename only: "antigravity_oauth_1.json" -> "antigravity" - API key style: stored with provider prefix metadata Args: credential: The credential identifier (path or key) Returns: Provider name string or None if cannot be determined """ import re # Normalize path separators normalized = credential.replace("\\", "/") # Pattern: path ending with {provider}_oauth_{number}.json match = re.search(r"/([a-z_]+)_oauth_\d+\.json$", normalized, re.IGNORECASE) if match: return match.group(1).lower() # Pattern: oauth_creds/{provider}_... match = re.search(r"oauth_creds/([a-z_]+)_", normalized, re.IGNORECASE) if match: return match.group(1).lower() # Pattern: filename only {provider}_oauth_{number}.json (no path) match = re.match(r"([a-z_]+)_oauth_\d+\.json$", normalized, re.IGNORECASE) if match: return match.group(1).lower() return None def _get_provider_instance(self, provider: str) -> Optional[Any]: """ Get or create a provider plugin instance. Args: provider: The provider name Returns: Provider plugin instance or None """ if not provider: return None plugin_class = self.provider_plugins.get(provider) if not plugin_class: return None # Get or create provider instance from cache if provider not in self._provider_instances: # Instantiate the plugin if it's a class, or use it directly if already an instance if isinstance(plugin_class, type): self._provider_instances[provider] = plugin_class() else: self._provider_instances[provider] = plugin_class return self._provider_instances[provider] def _get_usage_reset_config(self, credential: str) -> Optional[Dict[str, Any]]: """ Get the usage reset configuration for a credential from its provider plugin. Args: credential: The credential identifier Returns: Configuration dict with window_seconds, field_name, etc. or None to use default daily reset. 
""" provider = self._get_provider_from_credential(credential) plugin_instance = self._get_provider_instance(provider) if plugin_instance and hasattr(plugin_instance, "get_usage_reset_config"): return plugin_instance.get_usage_reset_config(credential) return None def _get_reset_mode(self, credential: str) -> str: """ Get the reset mode for a credential: 'credential' or 'per_model'. Args: credential: The credential identifier Returns: "per_model" or "credential" (default) """ config = self._get_usage_reset_config(credential) return config.get("mode", "credential") if config else "credential" def _get_model_quota_group(self, credential: str, model: str) -> Optional[str]: """ Get the quota group for a model, if the provider defines one. Args: credential: The credential identifier model: Model name (with or without provider prefix) Returns: Group name (e.g., "claude") or None if not grouped """ provider = self._get_provider_from_credential(credential) plugin_instance = self._get_provider_instance(provider) if plugin_instance and hasattr(plugin_instance, "get_model_quota_group"): return plugin_instance.get_model_quota_group(model) return None def _get_grouped_models(self, credential: str, group: str) -> List[str]: """ Get all model names in a quota group (with provider prefix). Args: credential: The credential identifier group: Group name (e.g., "claude") Returns: List of full model names (e.g., ["antigravity/claude-opus-4-5", ...]) """ provider = self._get_provider_from_credential(credential) plugin_instance = self._get_provider_instance(provider) if plugin_instance and hasattr(plugin_instance, "get_models_in_quota_group"): models = plugin_instance.get_models_in_quota_group(group) # Add provider prefix return [f"{provider}/{m}" for m in models] return [] def _get_model_usage_weight(self, credential: str, model: str) -> int: """ Get the usage weight for a model when calculating grouped usage. 
Args: credential: The credential identifier model: Model name (with or without provider prefix) Returns: Weight multiplier (default 1 if not configured) """ provider = self._get_provider_from_credential(credential) plugin_instance = self._get_provider_instance(provider) if plugin_instance and hasattr(plugin_instance, "get_model_usage_weight"): return plugin_instance.get_model_usage_weight(model) return 1 # Providers where request_count should be used for credential selection # instead of success_count (because failed requests also consume quota) _REQUEST_COUNT_PROVIDERS = {"antigravity"} def _get_grouped_usage_count(self, key: str, model: str) -> int: """ Get usage count for credential selection, considering quota groups. For providers in _REQUEST_COUNT_PROVIDERS (e.g., antigravity), uses request_count instead of success_count since failed requests also consume quota. If the model belongs to a quota group, the request_count is already synced across all models in the group (by record_success/record_failure), so we just read from the requested model directly. Args: key: Credential identifier model: Model name (with provider prefix, e.g., "antigravity/claude-sonnet-4-5") Returns: Usage count for the model (synced across group if applicable) """ # Determine usage field based on provider # Some providers (antigravity) count failed requests against quota provider = self._get_provider_from_credential(key) usage_field = ( "request_count" if provider in self._REQUEST_COUNT_PROVIDERS else "success_count" ) # For providers with synced quota groups (antigravity), request_count # is already synced across all models in the group, so just read directly. # For other providers, we still need to sum success_count across group. 
if provider in self._REQUEST_COUNT_PROVIDERS: # request_count is synced - just read the model's value return self._get_usage_count(key, model, usage_field) # For non-synced providers, check if model is in a quota group and sum group = self._get_model_quota_group(key, model) if group: # Get all models in the group grouped_models = self._get_grouped_models(key, group) # Sum weighted usage across all models in the group total_weighted_usage = 0 for grouped_model in grouped_models: usage = self._get_usage_count(key, grouped_model, usage_field) weight = self._get_model_usage_weight(key, grouped_model) total_weighted_usage += usage * weight return total_weighted_usage # Not grouped - return individual model usage (no weight applied) return self._get_usage_count(key, model, usage_field) def _get_quota_display(self, key: str, model: str) -> str: """ Get a formatted quota display string for logging. For antigravity (providers in _REQUEST_COUNT_PROVIDERS), returns: "quota: 170/250 [32%]" format For other providers, returns: "usage: 170" format (no max available) Args: key: Credential identifier model: Model name (with provider prefix) Returns: Formatted string for logging """ provider = self._get_provider_from_credential(key) if provider not in self._REQUEST_COUNT_PROVIDERS: # Non-antigravity: just show usage count usage = self._get_usage_count(key, model, "success_count") return f"usage: {usage}" # Antigravity: show quota display with remaining percentage if self._usage_data is None: return "quota: 0/? 
[100%]" key_data = self._usage_data.get(key, {}) model_data = key_data.get("models", {}).get(model, {}) request_count = model_data.get("request_count", 0) max_requests = model_data.get("quota_max_requests") if max_requests: remaining = max_requests - request_count remaining_pct = ( int((remaining / max_requests) * 100) if max_requests > 0 else 0 ) return f"quota: {request_count}/{max_requests} [{remaining_pct}%]" else: return f"quota: {request_count}" def _get_usage_field_name(self, credential: str) -> str: """ Get the usage tracking field name for a credential. Returns the provider-specific field name if configured, otherwise falls back to "daily". Args: credential: The credential identifier Returns: Field name string (e.g., "5h_window", "weekly", "daily") """ config = self._get_usage_reset_config(credential) if config and "field_name" in config: return config["field_name"] # Check provider default provider = self._get_provider_from_credential(credential) plugin_instance = self._get_provider_instance(provider) if plugin_instance and hasattr(plugin_instance, "get_default_usage_field_name"): return plugin_instance.get_default_usage_field_name() return "daily" def _get_usage_count( self, key: str, model: str, field: str = "success_count" ) -> int: """ Get the current usage count for a model from the appropriate usage structure. Supports both: - New per-model structure: {"models": {"model_name": {"success_count": N, ...}}} - Legacy structure: {"daily": {"models": {"model_name": {"success_count": N, ...}}}} Args: key: Credential identifier model: Model name field: The field to read for usage count (default: "success_count"). Use "request_count" for providers where failed requests also consume quota (e.g., antigravity). 
Returns: Usage count for the model in the current window/period """ if self._usage_data is None: return 0 key_data = self._usage_data.get(key, {}) reset_mode = self._get_reset_mode(key) if reset_mode == "per_model": # New per-model structure: key_data["models"][model][field] return key_data.get("models", {}).get(model, {}).get(field, 0) else: # Legacy structure: key_data["daily"]["models"][model][field] return ( key_data.get("daily", {}).get("models", {}).get(model, {}).get(field, 0) ) # ========================================================================= # TIMESTAMP FORMATTING HELPERS # ========================================================================= def _format_timestamp_local(self, ts: Optional[float]) -> Optional[str]: """ Format Unix timestamp as local time string with timezone offset. Args: ts: Unix timestamp or None Returns: Formatted string like "2025-12-07 14:30:17 +0100" or None """ if ts is None: return None try: dt = datetime.fromtimestamp(ts).astimezone() # Local timezone # Use UTC offset for conciseness (works on all platforms) return dt.strftime("%Y-%m-%d %H:%M:%S %z") except (OSError, ValueError, OverflowError): return None def _add_readable_timestamps(self, data: Dict) -> Dict: """ Add human-readable timestamp fields to usage data before saving. Adds 'window_started' and 'quota_resets' fields derived from Unix timestamps for easier debugging and monitoring. 
Args: data: The usage data dict to enhance Returns: The same dict with readable timestamp fields added """ for key, key_data in data.items(): # Handle per-model structure models = key_data.get("models", {}) for model_name, model_stats in models.items(): if not isinstance(model_stats, dict): continue # Add readable window start time window_start = model_stats.get("window_start_ts") if window_start: model_stats["window_started"] = self._format_timestamp_local( window_start ) elif "window_started" in model_stats: del model_stats["window_started"] # Add readable reset time quota_reset = model_stats.get("quota_reset_ts") if quota_reset: model_stats["quota_resets"] = self._format_timestamp_local( quota_reset ) elif "quota_resets" in model_stats: del model_stats["quota_resets"] return data def _sort_sequential( self, candidates: List[Tuple[str, int]], credential_priorities: Optional[Dict[str, int]] = None, ) -> List[Tuple[str, int]]: """ Sort credentials for sequential mode with position retention. Credentials maintain their position based on established usage patterns, ensuring that actively-used credentials remain primary until exhausted. Sorting order (within each sort key, lower value = higher priority): 1. Priority tier (lower number = higher priority) 2. Usage count (higher = more established in rotation, maintains position) 3. Last used timestamp (higher = more recent, tiebreaker for stickiness) 4. 
Credential ID (alphabetical, stable ordering) Args: candidates: List of (credential_id, usage_count) tuples credential_priorities: Optional dict mapping credentials to priority levels Returns: Sorted list of candidates (same format as input) """ if not candidates: return [] if len(candidates) == 1: return candidates def sort_key(item: Tuple[str, int]) -> Tuple[int, int, float, str]: cred, usage_count = item priority = ( credential_priorities.get(cred, 999) if credential_priorities else 999 ) last_used = ( self._usage_data.get(cred, {}).get("last_used_ts", 0) if self._usage_data else 0 ) return ( priority, # ASC: lower priority number = higher priority -usage_count, # DESC: higher usage = more established -last_used, # DESC: more recent = preferred for ties cred, # ASC: stable alphabetical ordering ) sorted_candidates = sorted(candidates, key=sort_key) # Debug logging - show top 3 credentials in ordering if lib_logger.isEnabledFor(logging.DEBUG): order_info = [ f"{mask_credential(c)}(p={credential_priorities.get(c, 999) if credential_priorities else 'N/A'}, u={u})" for c, u in sorted_candidates[:3] ] lib_logger.debug(f"Sequential ordering: {' → '.join(order_info)}") return sorted_candidates async def _lazy_init(self): """Initializes the usage data by loading it from the file asynchronously.""" async with self._init_lock: if not self._initialized.is_set(): await self._load_usage() await self._reset_daily_stats_if_needed() self._initialized.set() async def _load_usage(self): """Loads usage data from the JSON file asynchronously with resilience.""" async with self._data_lock: if not os.path.exists(self.file_path): self._usage_data = {} return try: async with aiofiles.open(self.file_path, "r") as f: content = await f.read() self._usage_data = json.loads(content) if content.strip() else {} except FileNotFoundError: # File deleted between exists check and open self._usage_data = {} except json.JSONDecodeError as e: lib_logger.warning( f"Corrupted usage file 
{self.file_path}: {e}. Starting fresh." ) self._usage_data = {} except (OSError, PermissionError, IOError) as e: lib_logger.warning( f"Cannot read usage file {self.file_path}: {e}. Using empty state." ) self._usage_data = {} async def _save_usage(self): """Saves the current usage data using the resilient state writer.""" if self._usage_data is None: return async with self._data_lock: # Add human-readable timestamp fields before saving self._add_readable_timestamps(self._usage_data) # Hand off to resilient writer - handles retries and disk failures self._state_writer.write(self._usage_data) async def _get_usage_data_snapshot(self) -> Dict[str, Any]: """ Get a shallow copy of the current usage data. Returns: Copy of usage data dict (safe for reading without lock) """ await self._lazy_init() async with self._data_lock: return dict(self._usage_data) if self._usage_data else {} async def get_available_credentials_for_model( self, credentials: List[str], model: str ) -> List[str]: """ Get credentials that are not on cooldown for a specific model. Filters out credentials where: - key_cooldown_until > now (key-level cooldown) - model_cooldowns[model] > now (model-specific cooldown, includes quota exhausted) Args: credentials: List of credential identifiers to check model: Model name to check cooldowns for Returns: List of credentials that are available (not on cooldown) for this model """ await self._lazy_init() now = time.time() available = [] async with self._data_lock: for key in credentials: key_data = self._usage_data.get(key, {}) # Skip if key-level cooldown is active if (key_data.get("key_cooldown_until") or 0) > now: continue # Skip if model-specific cooldown is active if (key_data.get("model_cooldowns", {}).get(model) or 0) > now: continue available.append(key) return available async def _reset_daily_stats_if_needed(self): """ Checks if usage stats need to be reset for any key. Supports three reset modes: 1. 
per_model: Each model has its own window, resets based on quota_reset_ts or fallback window 2. credential: One window per credential (legacy with custom window duration) 3. daily: Legacy daily reset at daily_reset_time_utc """ if self._usage_data is None: return now_utc = datetime.now(timezone.utc) now_ts = time.time() today_str = now_utc.date().isoformat() needs_saving = False for key, data in self._usage_data.items(): reset_config = self._get_usage_reset_config(key) if reset_config: reset_mode = reset_config.get("mode", "credential") if reset_mode == "per_model": # Per-model window reset needs_saving |= await self._check_per_model_resets( key, data, reset_config, now_ts ) else: # Credential-level window reset (legacy) needs_saving |= await self._check_window_reset( key, data, reset_config, now_ts ) elif self.daily_reset_time_utc: # Legacy daily reset needs_saving |= await self._check_daily_reset( key, data, now_utc, today_str, now_ts ) if needs_saving: await self._save_usage() async def _check_per_model_resets( self, key: str, data: Dict[str, Any], reset_config: Dict[str, Any], now_ts: float, ) -> bool: """ Check and perform per-model resets for a credential. Each model resets independently based on: 1. quota_reset_ts (authoritative, from quota exhausted error) if set 2. window_start_ts + window_seconds (fallback) otherwise Grouped models reset together - all models in a group must be ready. 
Args: key: Credential identifier data: Usage data for this credential reset_config: Provider's reset configuration now_ts: Current timestamp Returns: True if data was modified and needs saving """ window_seconds = reset_config.get("window_seconds", 86400) models_data = data.get("models", {}) if not models_data: return False modified = False processed_groups = set() for model, model_data in list(models_data.items()): # Check if this model is in a quota group group = self._get_model_quota_group(key, model) if group: if group in processed_groups: continue # Already handled this group # Check if entire group should reset if self._should_group_reset( key, group, models_data, window_seconds, now_ts ): # Archive and reset all models in group grouped_models = self._get_grouped_models(key, group) archived_count = 0 for grouped_model in grouped_models: if grouped_model in models_data: gm_data = models_data[grouped_model] self._archive_model_to_global(data, grouped_model, gm_data) self._reset_model_data(gm_data) archived_count += 1 if archived_count > 0: lib_logger.info( f"Reset model group '{group}' ({archived_count} models) for {mask_credential(key)}" ) modified = True processed_groups.add(group) else: # Ungrouped model - check individually if self._should_model_reset(model_data, window_seconds, now_ts): self._archive_model_to_global(data, model, model_data) self._reset_model_data(model_data) lib_logger.info(f"Reset model {model} for {mask_credential(key)}") modified = True # Preserve unexpired cooldowns if modified: self._preserve_unexpired_cooldowns(key, data, now_ts) if "failures" in data: data["failures"] = {} return modified def _should_model_reset( self, model_data: Dict[str, Any], window_seconds: int, now_ts: float ) -> bool: """ Check if a single model should reset. 
Returns True if: - quota_reset_ts is set AND now >= quota_reset_ts, OR - quota_reset_ts is NOT set AND now >= window_start_ts + window_seconds """ quota_reset = model_data.get("quota_reset_ts") window_start = model_data.get("window_start_ts") if quota_reset: return now_ts >= quota_reset elif window_start: return now_ts >= window_start + window_seconds return False def _should_group_reset( self, key: str, group: str, models_data: Dict[str, Dict], window_seconds: int, now_ts: float, ) -> bool: """ Check if all models in a group should reset. All models in the group must be ready to reset. If any model has an active cooldown/window, the whole group waits. """ grouped_models = self._get_grouped_models(key, group) # Track if any model in group has data any_has_data = False for grouped_model in grouped_models: model_data = models_data.get(grouped_model, {}) if not model_data or ( model_data.get("window_start_ts") is None and model_data.get("success_count", 0) == 0 ): continue # No stats for this model yet any_has_data = True if not self._should_model_reset(model_data, window_seconds, now_ts): return False # At least one model not ready return any_has_data def _archive_model_to_global( self, data: Dict[str, Any], model: str, model_data: Dict[str, Any] ) -> None: """Archive a single model's stats to global.""" global_data = data.setdefault("global", {"models": {}}) global_model = global_data["models"].setdefault( model, { "success_count": 0, "prompt_tokens": 0, "completion_tokens": 0, "approx_cost": 0.0, }, ) global_model["success_count"] += model_data.get("success_count", 0) global_model["prompt_tokens"] += model_data.get("prompt_tokens", 0) global_model["completion_tokens"] += model_data.get("completion_tokens", 0) global_model["approx_cost"] += model_data.get("approx_cost", 0.0) def _reset_model_data(self, model_data: Dict[str, Any]) -> None: """Reset a model's window and stats.""" model_data["window_start_ts"] = None model_data["quota_reset_ts"] = None 
model_data["success_count"] = 0 model_data["failure_count"] = 0 model_data["request_count"] = 0 model_data["prompt_tokens"] = 0 model_data["completion_tokens"] = 0 model_data["approx_cost"] = 0.0 # Reset quota baseline fields only if they exist (Antigravity-specific) # These are added by update_quota_baseline(), only called for Antigravity if "baseline_remaining_fraction" in model_data: model_data["baseline_remaining_fraction"] = None model_data["baseline_fetched_at"] = None model_data["requests_at_baseline"] = None # Reset quota display but keep max_requests (it doesn't change between periods) max_req = model_data.get("quota_max_requests") if max_req: model_data["quota_display"] = f"0/{max_req}" async def _check_window_reset( self, key: str, data: Dict[str, Any], reset_config: Dict[str, Any], now_ts: float, ) -> bool: """ Check and perform rolling window reset for a credential. Args: key: Credential identifier data: Usage data for this credential reset_config: Provider's reset configuration now_ts: Current timestamp Returns: True if data was modified and needs saving """ window_seconds = reset_config.get("window_seconds", 86400) # Default 24h field_name = reset_config.get("field_name", "window") description = reset_config.get("description", "rolling window") # Get current window data window_data = data.get(field_name, {}) window_start = window_data.get("start_ts") # No window started yet - nothing to reset if window_start is None: return False # Check if window has expired window_end = window_start + window_seconds if now_ts < window_end: # Window still active return False # Window expired - perform reset hours_elapsed = (now_ts - window_start) / 3600 lib_logger.info( f"Resetting {field_name} for {mask_credential(key)} - " f"{description} expired after {hours_elapsed:.1f}h" ) # Archive to global self._archive_to_global(data, window_data) # Preserve unexpired cooldowns self._preserve_unexpired_cooldowns(key, data, now_ts) # Reset window stats (but don't start new 
window until first request) data[field_name] = {"start_ts": None, "models": {}} # Reset consecutive failures if "failures" in data: data["failures"] = {} return True async def _check_daily_reset( self, key: str, data: Dict[str, Any], now_utc: datetime, today_str: str, now_ts: float, ) -> bool: """ Check and perform legacy daily reset for a credential. Args: key: Credential identifier data: Usage data for this credential now_utc: Current datetime in UTC today_str: Today's date as ISO string now_ts: Current timestamp Returns: True if data was modified and needs saving """ last_reset_str = data.get("last_daily_reset", "") if last_reset_str == today_str: return False last_reset_dt = None if last_reset_str: try: last_reset_dt = datetime.fromisoformat(last_reset_str).replace( tzinfo=timezone.utc ) except ValueError: pass # Determine the reset threshold for today reset_threshold_today = datetime.combine( now_utc.date(), self.daily_reset_time_utc ) if not ( last_reset_dt is None or last_reset_dt < reset_threshold_today <= now_utc ): return False lib_logger.debug(f"Performing daily reset for key {mask_credential(key)}") # Preserve unexpired cooldowns self._preserve_unexpired_cooldowns(key, data, now_ts) # Reset consecutive failures if "failures" in data: data["failures"] = {} # Archive daily stats to global daily_data = data.get("daily", {}) if daily_data: self._archive_to_global(data, daily_data) # Reset daily stats data["daily"] = {"date": today_str, "models": {}} data["last_daily_reset"] = today_str return True def _archive_to_global( self, data: Dict[str, Any], source_data: Dict[str, Any] ) -> None: """ Archive usage stats from a source field (daily/window) to global. 
Args: data: The credential's usage data source_data: The source field data to archive (has "models" key) """ global_data = data.setdefault("global", {"models": {}}) for model, stats in source_data.get("models", {}).items(): global_model_stats = global_data["models"].setdefault( model, { "success_count": 0, "prompt_tokens": 0, "completion_tokens": 0, "approx_cost": 0.0, }, ) global_model_stats["success_count"] += stats.get("success_count", 0) global_model_stats["prompt_tokens"] += stats.get("prompt_tokens", 0) global_model_stats["completion_tokens"] += stats.get("completion_tokens", 0) global_model_stats["approx_cost"] += stats.get("approx_cost", 0.0) def _preserve_unexpired_cooldowns( self, key: str, data: Dict[str, Any], now_ts: float ) -> None: """ Preserve unexpired cooldowns during reset (important for long quota cooldowns). Args: key: Credential identifier (for logging) data: The credential's usage data now_ts: Current timestamp """ # Preserve unexpired model cooldowns if "model_cooldowns" in data: active_cooldowns = { model: end_time for model, end_time in data["model_cooldowns"].items() if end_time > now_ts } if active_cooldowns: max_remaining = max( end_time - now_ts for end_time in active_cooldowns.values() ) hours_remaining = max_remaining / 3600 lib_logger.info( f"Preserving {len(active_cooldowns)} active cooldown(s) " f"for key {mask_credential(key)} during reset " f"(longest: {hours_remaining:.1f}h remaining)" ) data["model_cooldowns"] = active_cooldowns else: data["model_cooldowns"] = {} # Preserve unexpired key-level cooldown if data.get("key_cooldown_until"): if data["key_cooldown_until"] <= now_ts: data["key_cooldown_until"] = None else: hours_remaining = (data["key_cooldown_until"] - now_ts) / 3600 lib_logger.info( f"Preserving key-level cooldown for {mask_credential(key)} " f"during reset ({hours_remaining:.1f}h remaining)" ) else: data["key_cooldown_until"] = None def _initialize_key_states(self, keys: List[str]): """Initializes state tracking 
for all provided keys if not already present.""" for key in keys: if key not in self.key_states: self.key_states[key] = { "lock": asyncio.Lock(), "condition": asyncio.Condition(), "models_in_use": {}, # Dict[model_name, concurrent_count] } def _select_weighted_random(self, candidates: List[tuple], tolerance: float) -> str: """ Selects a credential using weighted random selection based on usage counts. Args: candidates: List of (credential_id, usage_count) tuples tolerance: Tolerance value for weight calculation Returns: Selected credential ID Formula: weight = (max_usage - credential_usage) + tolerance + 1 This formula ensures: - Lower usage = higher weight = higher selection probability - Tolerance adds variability: higher tolerance means more randomness - The +1 ensures all credentials have at least some chance of selection """ if not candidates: raise ValueError("Cannot select from empty candidate list") if len(candidates) == 1: return candidates[0][0] # Extract usage counts usage_counts = [usage for _, usage in candidates] max_usage = max(usage_counts) # Calculate weights using the formula: (max - current) + tolerance + 1 weights = [] for credential, usage in candidates: weight = (max_usage - usage) + tolerance + 1 weights.append(weight) # Log weight distribution for debugging if lib_logger.isEnabledFor(logging.DEBUG): total_weight = sum(weights) weight_info = ", ".join( f"{mask_credential(cred)}: w={w:.1f} ({w / total_weight * 100:.1f}%)" for (cred, _), w in zip(candidates, weights) ) # lib_logger.debug(f"Weighted selection candidates: {weight_info}") # Random selection with weights selected_credential = random.choices( [cred for cred, _ in candidates], weights=weights, k=1 )[0] return selected_credential async def acquire_key( self, available_keys: List[str], model: str, deadline: float, max_concurrent: int = 1, credential_priorities: Optional[Dict[str, int]] = None, credential_tier_names: Optional[Dict[str, str]] = None, ) -> str: """ Acquires the best 
available key using a tiered, model-aware locking strategy,
        respecting a global deadline and credential priorities.

        Priority Logic:
        - Groups credentials by priority level (1=highest, 2=lower, etc.)
        - Always tries highest priority (lowest number) first
        - Within same priority, sorts by usage count (load balancing)
        - Only moves to next priority if all higher-priority keys exhausted/busy

        Args:
            available_keys: List of credential identifiers to choose from
            model: Model name being requested
            deadline: Timestamp after which to stop trying
            max_concurrent: Maximum concurrent requests allowed per credential
            credential_priorities: Optional dict mapping credentials to priority levels (1=highest)
            credential_tier_names: Optional dict mapping credentials to tier names (for logging)

        Returns:
            Selected credential identifier

        Raises:
            NoAvailableKeysError: If no key could be acquired within the deadline
        """
        await self._lazy_init()
        await self._reset_daily_stats_if_needed()
        self._initialize_key_states(available_keys)

        # This loop continues as long as the global deadline has not been met.
        while time.time() < deadline:
            now = time.time()

            # Group credentials by priority level (if priorities provided)
            if credential_priorities:
                # Group keys by priority level
                priority_groups = {}
                async with self._data_lock:
                    for key in available_keys:
                        key_data = self._usage_data.get(key, {})
                        # Skip keys on cooldown (key-level or for this model)
                        if (key_data.get("key_cooldown_until") or 0) > now or (
                            key_data.get("model_cooldowns", {}).get(model) or 0
                        ) > now:
                            continue
                        # Get priority for this key (default to 999 if not specified)
                        priority = credential_priorities.get(key, 999)
                        # Get usage count for load balancing within priority groups
                        # Uses grouped usage if model is in a quota group
                        usage_count = self._get_grouped_usage_count(key, model)
                        # Group by priority
                        if priority not in priority_groups:
                            priority_groups[priority] = []
                        priority_groups[priority].append((key, usage_count))

                # Try priority groups in order (1, 2, 3, ...)
                sorted_priorities = sorted(priority_groups.keys())
                for priority_level in sorted_priorities:
                    keys_in_priority = priority_groups[priority_level]
                    # Determine selection method based on provider's rotation mode
                    provider = model.split("/")[0] if "/" in model else ""
                    rotation_mode = self._get_rotation_mode(provider)
                    # Calculate effective concurrency based on priority tier
                    multiplier = self._get_priority_multiplier(
                        provider, priority_level, rotation_mode
                    )
                    effective_max_concurrent = max_concurrent * multiplier
                    # Within each priority group, use existing tier1/tier2 logic
                    tier1_keys, tier2_keys = [], []
                    for key, usage_count in keys_in_priority:
                        key_state = self.key_states[key]
                        # Tier 1: Completely idle keys (preferred)
                        if not key_state["models_in_use"]:
                            tier1_keys.append((key, usage_count))
                        # Tier 2: Keys that can accept more concurrent requests
                        elif (
                            key_state["models_in_use"].get(model, 0)
                            < effective_max_concurrent
                        ):
                            tier2_keys.append((key, usage_count))

                    if rotation_mode == "sequential":
                        # Sequential mode: sort credentials by priority, usage, recency
                        # Keep all candidates in sorted order (no filtering to single key)
                        selection_method = "sequential"
                        if tier1_keys:
                            tier1_keys = self._sort_sequential(
                                tier1_keys, credential_priorities
                            )
                        if tier2_keys:
                            tier2_keys = self._sort_sequential(
                                tier2_keys, credential_priorities
                            )
                    elif self.rotation_tolerance > 0:
                        # Balanced mode with weighted randomness: pick one key per
                        # tier, then narrow the tier list to that single winner.
                        selection_method = "weighted-random"
                        if tier1_keys:
                            selected_key = self._select_weighted_random(
                                tier1_keys, self.rotation_tolerance
                            )
                            tier1_keys = [
                                (k, u) for k, u in tier1_keys if k == selected_key
                            ]
                        if tier2_keys:
                            selected_key = self._select_weighted_random(
                                tier2_keys, self.rotation_tolerance
                            )
                            tier2_keys = [
                                (k, u) for k, u in tier2_keys if k == selected_key
                            ]
                    else:
                        # Deterministic: sort by usage within each tier
                        selection_method = "least-used"
                        tier1_keys.sort(key=lambda x: x[1])
                        tier2_keys.sort(key=lambda x: x[1])

                    # Try to acquire from Tier 1 first
                    for key, usage in tier1_keys:
                        state = self.key_states[key]
                        async with state["lock"]:
                            # Re-check idleness under the lock: another task may
                            # have grabbed this key since tiering.
                            if not state["models_in_use"]:
                                state["models_in_use"][model] = 1
                                tier_name = (
                                    credential_tier_names.get(key, "unknown")
                                    if credential_tier_names
                                    else "unknown"
                                )
                                quota_display = self._get_quota_display(key, model)
                                lib_logger.info(
                                    f"Acquired key {mask_credential(key)} for model {model} "
                                    f"(tier: {tier_name}, priority: {priority_level}, selection: {selection_method}, {quota_display})"
                                )
                                return key

                    # Then try Tier 2
                    for key, usage in tier2_keys:
                        state = self.key_states[key]
                        async with state["lock"]:
                            current_count = state["models_in_use"].get(model, 0)
                            if current_count < effective_max_concurrent:
                                state["models_in_use"][model] = current_count + 1
                                tier_name = (
                                    credential_tier_names.get(key, "unknown")
                                    if credential_tier_names
                                    else "unknown"
                                )
                                quota_display = self._get_quota_display(key, model)
                                lib_logger.info(
                                    f"Acquired key {mask_credential(key)} for model {model} "
                                    f"(tier: {tier_name}, priority: {priority_level}, selection: {selection_method}, concurrent: {state['models_in_use'][model]}/{effective_max_concurrent}, {quota_display})"
                                )
                                return key

                # If we get here, all priority groups were exhausted but keys might become available
                # Collect all keys across all priorities for waiting
                all_potential_keys = []
                for keys_list in priority_groups.values():
                    all_potential_keys.extend(keys_list)
                if not all_potential_keys:
                    lib_logger.warning(
                        "No keys are eligible (all on cooldown or filtered out). Waiting before re-evaluating."
                    )
                    await asyncio.sleep(1)
                    continue
                # Wait for the highest priority key with lowest usage
                best_priority = min(priority_groups.keys())
                best_priority_keys = priority_groups[best_priority]
                best_wait_key = min(best_priority_keys, key=lambda x: x[1])[0]
                wait_condition = self.key_states[best_wait_key]["condition"]
                lib_logger.info(
                    f"All Priority-{best_priority} keys are busy. Waiting for highest priority credential to become available..."
                )
            else:
                # Original logic when no priorities specified
                # Determine selection method based on provider's rotation mode
                provider = model.split("/")[0] if "/" in model else ""
                rotation_mode = self._get_rotation_mode(provider)

                # Calculate effective concurrency for default priority (999)
                # When no priorities are specified, all credentials get default priority
                default_priority = 999
                multiplier = self._get_priority_multiplier(
                    provider, default_priority, rotation_mode
                )
                effective_max_concurrent = max_concurrent * multiplier

                tier1_keys, tier2_keys = [], []
                # First, filter the list of available keys to exclude any on cooldown.
                async with self._data_lock:
                    for key in available_keys:
                        key_data = self._usage_data.get(key, {})
                        if (key_data.get("key_cooldown_until") or 0) > now or (
                            key_data.get("model_cooldowns", {}).get(model) or 0
                        ) > now:
                            continue
                        # Prioritize keys based on their current usage to ensure load balancing.
                        # Uses grouped usage if model is in a quota group
                        usage_count = self._get_grouped_usage_count(key, model)
                        key_state = self.key_states[key]
                        # Tier 1: Completely idle keys (preferred).
                        if not key_state["models_in_use"]:
                            tier1_keys.append((key, usage_count))
                        # Tier 2: Keys that can accept more concurrent requests for this model.
                        elif (
                            key_state["models_in_use"].get(model, 0)
                            < effective_max_concurrent
                        ):
                            tier2_keys.append((key, usage_count))

                if rotation_mode == "sequential":
                    # Sequential mode: sort credentials by priority, usage, recency
                    # Keep all candidates in sorted order (no filtering to single key)
                    selection_method = "sequential"
                    if tier1_keys:
                        tier1_keys = self._sort_sequential(
                            tier1_keys, credential_priorities
                        )
                    if tier2_keys:
                        tier2_keys = self._sort_sequential(
                            tier2_keys, credential_priorities
                        )
                elif self.rotation_tolerance > 0:
                    # Balanced mode with weighted randomness
                    selection_method = "weighted-random"
                    if tier1_keys:
                        selected_key = self._select_weighted_random(
                            tier1_keys, self.rotation_tolerance
                        )
                        tier1_keys = [
                            (k, u) for k, u in tier1_keys if k == selected_key
                        ]
                    if tier2_keys:
                        selected_key = self._select_weighted_random(
                            tier2_keys, self.rotation_tolerance
                        )
                        tier2_keys = [
                            (k, u) for k, u in tier2_keys if k == selected_key
                        ]
                else:
                    # Deterministic: sort by usage within each tier
                    selection_method = "least-used"
                    tier1_keys.sort(key=lambda x: x[1])
                    tier2_keys.sort(key=lambda x: x[1])

                # Attempt to acquire a key from Tier 1 first.
                for key, usage in tier1_keys:
                    state = self.key_states[key]
                    async with state["lock"]:
                        if not state["models_in_use"]:
                            state["models_in_use"][model] = 1
                            tier_name = (
                                credential_tier_names.get(key)
                                if credential_tier_names
                                else None
                            )
                            tier_info = f"tier: {tier_name}, " if tier_name else ""
                            quota_display = self._get_quota_display(key, model)
                            lib_logger.info(
                                f"Acquired key {mask_credential(key)} for model {model} "
                                f"({tier_info}selection: {selection_method}, {quota_display})"
                            )
                            return key

                # If no Tier 1 keys are available, try Tier 2.
                for key, usage in tier2_keys:
                    state = self.key_states[key]
                    async with state["lock"]:
                        current_count = state["models_in_use"].get(model, 0)
                        if current_count < effective_max_concurrent:
                            state["models_in_use"][model] = current_count + 1
                            tier_name = (
                                credential_tier_names.get(key)
                                if credential_tier_names
                                else None
                            )
                            tier_info = f"tier: {tier_name}, " if tier_name else ""
                            quota_display = self._get_quota_display(key, model)
                            lib_logger.info(
                                f"Acquired key {mask_credential(key)} for model {model} "
                                f"({tier_info}selection: {selection_method}, concurrent: {state['models_in_use'][model]}/{effective_max_concurrent}, {quota_display})"
                            )
                            return key

                # If all eligible keys are locked, wait for a key to be released.
                lib_logger.info(
                    "All eligible keys are currently locked for this model. Waiting..."
                )
                all_potential_keys = tier1_keys + tier2_keys
                if not all_potential_keys:
                    lib_logger.warning(
                        "No keys are eligible (all on cooldown). Waiting before re-evaluating."
                    )
                    await asyncio.sleep(1)
                    continue
                # Wait on the condition of the key with the lowest current usage.
                best_wait_key = min(all_potential_keys, key=lambda x: x[1])[0]
                wait_condition = self.key_states[best_wait_key]["condition"]

            # Shared wait: both branches above set wait_condition before falling
            # through here.
            try:
                async with wait_condition:
                    remaining_budget = deadline - time.time()
                    if remaining_budget <= 0:
                        break  # Exit if the budget has already been exceeded.
                    # Wait for a notification, but no longer than the remaining budget or 1 second.
                    await asyncio.wait_for(
                        wait_condition.wait(), timeout=min(1, remaining_budget)
                    )
                    lib_logger.info(
                        "Notified that a key was released. Re-evaluating..."
                    )
            except asyncio.TimeoutError:
                # This is not an error, just a timeout for the wait. The main loop will re-evaluate.
                lib_logger.info("Wait timed out. Re-evaluating for any available key.")

        # If the loop exits, it means the deadline was exceeded.
        raise NoAvailableKeysError(
            f"Could not acquire a key for model {model} within the global time budget."
        )

    async def release_key(self, key: str, model: str):
        """Releases a key's lock for a specific model and notifies waiting tasks."""
        if key not in self.key_states:
            return
        state = self.key_states[key]
        async with state["lock"]:
            if model in state["models_in_use"]:
                state["models_in_use"][model] -= 1
                remaining = state["models_in_use"][model]
                if remaining <= 0:
                    del state["models_in_use"][model]  # Clean up when count reaches 0
                lib_logger.info(
                    f"Released credential {mask_credential(key)} from model {model} "
                    f"(remaining concurrent: {max(0, remaining)})"
                )
            else:
                lib_logger.warning(
                    f"Attempted to release credential {mask_credential(key)} for model {model}, but it was not in use."
                )
        # Notify all tasks waiting on this key's condition
        async with state["condition"]:
            state["condition"].notify_all()

    async def record_success(
        self,
        key: str,
        model: str,
        completion_response: Optional[litellm.ModelResponse] = None,
    ):
        """
        Records a successful API call, resetting failure counters.
        It safely handles cases where token usage data is not available.
        Supports two modes based on provider configuration:
        - per_model: Each model has its own window_start_ts and stats in key_data["models"]
        - credential: Legacy mode with key_data["daily"]["models"]
        """
        await self._lazy_init()
        async with self._data_lock:
            now_ts = time.time()
            today_utc_str = datetime.now(timezone.utc).date().isoformat()
            reset_config = self._get_usage_reset_config(key)
            reset_mode = (
                reset_config.get("mode", "credential") if reset_config else "credential"
            )

            if reset_mode == "per_model":
                # New per-model structure
                key_data = self._usage_data.setdefault(
                    key,
                    {
                        "models": {},
                        "global": {"models": {}},
                        "model_cooldowns": {},
                        "failures": {},
                    },
                )
                # Ensure models dict exists
                if "models" not in key_data:
                    key_data["models"] = {}
                # Get or create per-model data with window tracking
                model_data = key_data["models"].setdefault(
                    model,
                    {
                        "window_start_ts": None,
                        "quota_reset_ts": None,
                        "success_count": 0,
                        "failure_count": 0,
                        "request_count": 0,
                        "prompt_tokens": 0,
                        "completion_tokens": 0,
                        "approx_cost": 0.0,
                    },
                )
                # Start window on first request for this model
                if model_data.get("window_start_ts") is None:
                    model_data["window_start_ts"] = now_ts
                    # Set expected quota reset time from provider config
                    window_seconds = (
                        reset_config.get("window_seconds", 0) if reset_config else 0
                    )
                    if window_seconds > 0:
                        model_data["quota_reset_ts"] = now_ts + window_seconds
                    window_hours = window_seconds / 3600 if window_seconds else 0
                    lib_logger.info(
                        f"Started {window_hours:.1f}h window for model {model} on {mask_credential(key)}"
                    )
                # Record stats
                model_data["success_count"] += 1
                model_data["request_count"] = model_data.get("request_count", 0) + 1

                # Sync request_count across quota group (for providers with shared quota pools)
                new_request_count = model_data["request_count"]
                group = self._get_model_quota_group(key, model)
                if group:
                    grouped_models = self._get_grouped_models(key, group)
                    for grouped_model in grouped_models:
                        if grouped_model != model:
                            other_model_data = key_data["models"].setdefault(
                                grouped_model,
                                {
                                    "window_start_ts": None,
                                    "quota_reset_ts": None,
                                    "success_count": 0,
                                    "failure_count": 0,
                                    "request_count": 0,
                                    "prompt_tokens": 0,
                                    "completion_tokens": 0,
                                    "approx_cost": 0.0,
                                },
                            )
                            other_model_data["request_count"] = new_request_count
                            # Also sync quota_max_requests if set
                            max_req = model_data.get("quota_max_requests")
                            if max_req:
                                other_model_data["quota_max_requests"] = max_req
                                other_model_data["quota_display"] = (
                                    f"{new_request_count}/{max_req}"
                                )

                # Update quota_display if max_requests is set (Antigravity-specific)
                max_req = model_data.get("quota_max_requests")
                if max_req:
                    model_data["quota_display"] = (
                        f"{model_data['request_count']}/{max_req}"
                    )

                usage_data_ref = model_data  # For token/cost recording below
            else:
                # Legacy credential-level structure
                key_data = self._usage_data.setdefault(
                    key,
                    {
                        "daily": {"date": today_utc_str, "models": {}},
                        "global": {"models": {}},
                        "model_cooldowns": {},
                        "failures": {},
                    },
                )
                if "last_daily_reset" not in key_data:
                    key_data["last_daily_reset"] = today_utc_str
                # Get or create model data in daily structure
                usage_data_ref = key_data["daily"]["models"].setdefault(
                    model,
                    {
                        "success_count": 0,
                        "prompt_tokens": 0,
                        "completion_tokens": 0,
                        "approx_cost": 0.0,
                    },
                )
                usage_data_ref["success_count"] += 1

            # Reset failures for this model
            model_failures = key_data.setdefault("failures", {}).setdefault(model, {})
            model_failures["consecutive_failures"] = 0
            # Clear transient cooldown on success (but NOT quota_reset_ts)
            if model in key_data.get("model_cooldowns", {}):
                del key_data["model_cooldowns"][model]

            # Record token and cost usage
            if (
                completion_response
                and hasattr(completion_response, "usage")
                and completion_response.usage
            ):
                usage = completion_response.usage
                usage_data_ref["prompt_tokens"] += usage.prompt_tokens
                usage_data_ref["completion_tokens"] += getattr(
                    usage, "completion_tokens", 0
                )
                lib_logger.info(
                    f"Recorded usage from response object for key {mask_credential(key)}"
                )
                try:
                    provider_name = model.split("/")[0]
                    provider_instance = self._get_provider_instance(provider_name)
                    if provider_instance and getattr(
                        provider_instance, "skip_cost_calculation", False
                    ):
                        lib_logger.debug(
                            f"Skipping cost calculation for provider '{provider_name}' (custom provider)."
                        )
                    else:
                        if isinstance(completion_response, litellm.EmbeddingResponse):
                            # Embeddings have no completion side; price prompt tokens only.
                            model_info = litellm.get_model_info(model)
                            input_cost = model_info.get("input_cost_per_token")
                            if input_cost:
                                cost = (
                                    completion_response.usage.prompt_tokens * input_cost
                                )
                            else:
                                cost = None
                        else:
                            cost = litellm.completion_cost(
                                completion_response=completion_response, model=model
                            )
                        if cost is not None:
                            usage_data_ref["approx_cost"] += cost
                except Exception as e:
                    # Cost is best-effort; never fail the success recording over it.
                    lib_logger.warning(
                        f"Could not calculate cost for model {model}: {e}"
                    )
            elif isinstance(completion_response, asyncio.Future) or hasattr(
                completion_response, "__aiter__"
            ):
                pass  # Stream - usage recorded from chunks
            else:
                lib_logger.warning(
                    f"No usage data found in completion response for model {model}. Recording success without token count."
                )

            key_data["last_used_ts"] = now_ts
            await self._save_usage()

    async def record_failure(
        self,
        key: str,
        model: str,
        classified_error: ClassifiedError,
        increment_consecutive_failures: bool = True,
    ):
        """Records a failure and applies cooldowns based on error type.

        Distinguishes between:
        - quota_exceeded: Long cooldown with exact reset time (from quota_reset_timestamp)
          Sets quota_reset_ts on model (and group) - this becomes authoritative stats reset time
        - rate_limit: Short transient cooldown (just wait and retry)
          Only sets model_cooldowns - does NOT affect stats reset timing

        Args:
            key: The API key or credential identifier
            model: The model name
            classified_error: The classified error object
            increment_consecutive_failures: Whether to increment the failure counter.
                Set to False for provider-level errors that shouldn't count against the key.
""" await self._lazy_init() async with self._data_lock: now_ts = time.time() today_utc_str = datetime.now(timezone.utc).date().isoformat() reset_config = self._get_usage_reset_config(key) reset_mode = ( reset_config.get("mode", "credential") if reset_config else "credential" ) # Initialize key data with appropriate structure if reset_mode == "per_model": key_data = self._usage_data.setdefault( key, { "models": {}, "global": {"models": {}}, "model_cooldowns": {}, "failures": {}, }, ) else: key_data = self._usage_data.setdefault( key, { "daily": {"date": today_utc_str, "models": {}}, "global": {"models": {}}, "model_cooldowns": {}, "failures": {}, }, ) # Provider-level errors (transient issues) should not count against the key provider_level_errors = {"server_error", "api_connection"} # Determine if we should increment the failure counter should_increment = ( increment_consecutive_failures and classified_error.error_type not in provider_level_errors ) # Calculate cooldown duration based on error type cooldown_seconds = None model_cooldowns = key_data.setdefault("model_cooldowns", {}) if classified_error.error_type == "quota_exceeded": # Quota exhausted - use authoritative reset timestamp if available quota_reset_ts = classified_error.quota_reset_timestamp cooldown_seconds = classified_error.retry_after or 60 if quota_reset_ts and reset_mode == "per_model": # Set quota_reset_ts on model - this becomes authoritative stats reset time models_data = key_data.setdefault("models", {}) model_data = models_data.setdefault( model, { "window_start_ts": None, "quota_reset_ts": None, "success_count": 0, "failure_count": 0, "request_count": 0, "prompt_tokens": 0, "completion_tokens": 0, "approx_cost": 0.0, }, ) model_data["quota_reset_ts"] = quota_reset_ts # Track failure for quota estimation (request still consumes quota) model_data["failure_count"] = model_data.get("failure_count", 0) + 1 model_data["request_count"] = model_data.get("request_count", 0) + 1 new_request_count = 
model_data["request_count"] # Apply to all models in the same quota group group = self._get_model_quota_group(key, model) if group: grouped_models = self._get_grouped_models(key, group) for grouped_model in grouped_models: group_model_data = models_data.setdefault( grouped_model, { "window_start_ts": None, "quota_reset_ts": None, "success_count": 0, "failure_count": 0, "request_count": 0, "prompt_tokens": 0, "completion_tokens": 0, "approx_cost": 0.0, }, ) group_model_data["quota_reset_ts"] = quota_reset_ts # Sync request_count across quota group group_model_data["request_count"] = new_request_count # Also sync quota_max_requests if set max_req = model_data.get("quota_max_requests") if max_req: group_model_data["quota_max_requests"] = max_req group_model_data["quota_display"] = ( f"{new_request_count}/{max_req}" ) # Also set transient cooldown for selection logic model_cooldowns[grouped_model] = quota_reset_ts reset_dt = datetime.fromtimestamp( quota_reset_ts, tz=timezone.utc ) lib_logger.info( f"Quota exhausted for group '{group}' ({len(grouped_models)} models) " f"on {mask_credential(key)}. Resets at {reset_dt.isoformat()}" ) else: reset_dt = datetime.fromtimestamp( quota_reset_ts, tz=timezone.utc ) hours = (quota_reset_ts - now_ts) / 3600 lib_logger.info( f"Quota exhausted for model {model} on {mask_credential(key)}. " f"Resets at {reset_dt.isoformat()} ({hours:.1f}h)" ) # Set transient cooldown for selection logic model_cooldowns[model] = quota_reset_ts else: # No authoritative timestamp or legacy mode - just use retry_after model_cooldowns[model] = now_ts + cooldown_seconds hours = cooldown_seconds / 3600 lib_logger.info( f"Quota exhausted on {mask_credential(key)} for model {model}. 
" f"Cooldown: {cooldown_seconds}s ({hours:.1f}h)" ) elif classified_error.error_type == "rate_limit": # Transient rate limit - just set short cooldown (does NOT set quota_reset_ts) cooldown_seconds = classified_error.retry_after or 60 model_cooldowns[model] = now_ts + cooldown_seconds lib_logger.info( f"Rate limit on {mask_credential(key)} for model {model}. " f"Transient cooldown: {cooldown_seconds}s" ) elif classified_error.error_type == "authentication": # Apply a 5-minute key-level lockout for auth errors key_data["key_cooldown_until"] = now_ts + 300 cooldown_seconds = 300 model_cooldowns[model] = now_ts + cooldown_seconds lib_logger.warning( f"Authentication error on key {mask_credential(key)}. Applying 5-minute key-level lockout." ) # If we should increment failures, calculate escalating backoff if should_increment: failures_data = key_data.setdefault("failures", {}) model_failures = failures_data.setdefault( model, {"consecutive_failures": 0} ) model_failures["consecutive_failures"] += 1 count = model_failures["consecutive_failures"] # If cooldown wasn't set by specific error type, use escalating backoff if cooldown_seconds is None: backoff_tiers = {1: 10, 2: 30, 3: 60, 4: 120} cooldown_seconds = backoff_tiers.get(count, 7200) model_cooldowns[model] = now_ts + cooldown_seconds lib_logger.warning( f"Failure #{count} for key {mask_credential(key)} with model {model}. " f"Error type: {classified_error.error_type}, cooldown: {cooldown_seconds}s" ) else: # Provider-level errors: apply short cooldown but don't count against key if cooldown_seconds is None: cooldown_seconds = 30 model_cooldowns[model] = now_ts + cooldown_seconds lib_logger.info( f"Provider-level error ({classified_error.error_type}) for key {mask_credential(key)} " f"with model {model}. NOT incrementing failures. 
Cooldown: {cooldown_seconds}s" ) # Check for key-level lockout condition await self._check_key_lockout(key, key_data) # Track failure count for quota estimation (all failures consume quota) # This is separate from consecutive_failures which is for backoff logic if reset_mode == "per_model": models_data = key_data.setdefault("models", {}) model_data = models_data.setdefault( model, { "window_start_ts": None, "quota_reset_ts": None, "success_count": 0, "failure_count": 0, "request_count": 0, "prompt_tokens": 0, "completion_tokens": 0, "approx_cost": 0.0, }, ) # Only increment if not already incremented in quota_exceeded branch if classified_error.error_type != "quota_exceeded": model_data["failure_count"] = model_data.get("failure_count", 0) + 1 model_data["request_count"] = model_data.get("request_count", 0) + 1 # Sync request_count across quota group new_request_count = model_data["request_count"] group = self._get_model_quota_group(key, model) if group: grouped_models = self._get_grouped_models(key, group) for grouped_model in grouped_models: if grouped_model != model: other_model_data = models_data.setdefault( grouped_model, { "window_start_ts": None, "quota_reset_ts": None, "success_count": 0, "failure_count": 0, "request_count": 0, "prompt_tokens": 0, "completion_tokens": 0, "approx_cost": 0.0, }, ) other_model_data["request_count"] = new_request_count # Also sync quota_max_requests if set max_req = model_data.get("quota_max_requests") if max_req: other_model_data["quota_max_requests"] = max_req other_model_data["quota_display"] = ( f"{new_request_count}/{max_req}" ) key_data["last_failure"] = { "timestamp": now_ts, "model": model, "error": str(classified_error.original_exception), } await self._save_usage() async def update_quota_baseline( self, credential: str, model: str, remaining_fraction: float, max_requests: Optional[int] = None, ) -> None: """ Update quota baseline data for a credential/model after fetching from API. 
        This stores the current quota state as a baseline, which is used to
        estimate remaining quota based on subsequent request counts.

        Args:
            credential: Credential identifier (file path or env:// URI)
            model: Model name (with or without provider prefix)
            remaining_fraction: Current remaining quota as fraction (0.0 to 1.0)
            max_requests: Maximum requests allowed per quota period (e.g., 250 for Claude)
        """
        await self._lazy_init()
        async with self._data_lock:
            now_ts = time.time()
            # Get or create key data structure
            key_data = self._usage_data.setdefault(
                credential,
                {
                    "models": {},
                    "global": {"models": {}},
                    "model_cooldowns": {},
                    "failures": {},
                },
            )
            # Ensure models dict exists
            if "models" not in key_data:
                key_data["models"] = {}
            # Get or create per-model data
            model_data = key_data["models"].setdefault(
                model,
                {
                    "window_start_ts": None,
                    "quota_reset_ts": None,
                    "success_count": 0,
                    "failure_count": 0,
                    "request_count": 0,
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "approx_cost": 0.0,
                    "baseline_remaining_fraction": None,
                    "baseline_fetched_at": None,
                    "requests_at_baseline": None,
                },
            )

            # Calculate actual used requests from API's remaining fraction
            # The API is authoritative - sync our local count to match reality
            if max_requests is not None:
                used_requests = int((1.0 - remaining_fraction) * max_requests)
            else:
                # Estimate max_requests from provider's quota cost
                # This matches how get_max_requests_for_model() calculates it
                provider = self._get_provider_from_credential(credential)
                plugin_instance = self._get_provider_instance(provider)
                if plugin_instance and hasattr(
                    plugin_instance, "get_max_requests_for_model"
                ):
                    # Get tier from provider's cache
                    tier = getattr(plugin_instance, "project_tier_cache", {}).get(
                        credential, "standard-tier"
                    )
                    # Strip provider prefix from model if present
                    clean_model = model.split("/")[-1] if "/" in model else model
                    max_requests = plugin_instance.get_max_requests_for_model(
                        clean_model, tier
                    )
                    used_requests = int((1.0 - remaining_fraction) * max_requests)
                else:
                    # Fallback: keep existing count if we can't calculate
                    used_requests = model_data.get("request_count", 0)
                    max_requests = model_data.get("quota_max_requests")

            # Sync local request count to API's authoritative value
            model_data["request_count"] = used_requests
            model_data["requests_at_baseline"] = used_requests
            # Update baseline fields
            model_data["baseline_remaining_fraction"] = remaining_fraction
            model_data["baseline_fetched_at"] = now_ts
            # Update max_requests and quota_display
            if max_requests is not None:
                model_data["quota_max_requests"] = max_requests
                model_data["quota_display"] = f"{used_requests}/{max_requests}"

            # Sync request_count and quota_max_requests across quota group
            group = self._get_model_quota_group(credential, model)
            if group:
                grouped_models = self._get_grouped_models(credential, group)
                for grouped_model in grouped_models:
                    if grouped_model != model:
                        other_model_data = key_data["models"].setdefault(
                            grouped_model,
                            {
                                "window_start_ts": None,
                                "quota_reset_ts": None,
                                "success_count": 0,
                                "failure_count": 0,
                                "request_count": 0,
                                "prompt_tokens": 0,
                                "completion_tokens": 0,
                                "approx_cost": 0.0,
                            },
                        )
                        other_model_data["request_count"] = used_requests
                        if max_requests is not None:
                            other_model_data["quota_max_requests"] = max_requests
                            other_model_data["quota_display"] = (
                                f"{used_requests}/{max_requests}"
                            )

            lib_logger.debug(
                f"Updated quota baseline for {mask_credential(credential)} model={model}: "
                f"remaining={remaining_fraction:.2%}, synced_request_count={used_requests}"
            )
            await self._save_usage()

    async def _check_key_lockout(self, key: str, key_data: Dict):
        """
        Checks if a key should be locked out due to multiple model failures.

        NOTE: This check is currently disabled.
        The original logic counted individual models in long-term lockout, but this
        caused issues with quota groups - when a single quota group (e.g., "claude"
        with 5 models) was exhausted, it would count as 5 lockouts and trigger
        key-level lockout, blocking other quota groups (like gemini) that were
        still available.

        The per-model and per-group cooldowns already handle quota exhaustion properly.
        """
        # Disabled - see docstring above
        pass

    async def get_stats_for_endpoint(
        self,
        provider_filter: Optional[str] = None,
        include_global: bool = True,
    ) -> Dict[str, Any]:
        """
        Get usage stats formatted for the /v1/quota-stats endpoint.

        Aggregates data from key_usage.json grouped by provider.
        Includes both current period stats and global (lifetime) stats.

        Args:
            provider_filter: If provided, only return stats for this provider
            include_global: If True, include global/lifetime stats alongside current

        Returns:
            {
                "providers": {
                    "provider_name": {
                        "credential_count": int,
                        "active_count": int,
                        "on_cooldown_count": int,
                        "total_requests": int,
                        "tokens": {
                            "input_cached": int,
                            "input_uncached": int,
                            "input_cache_pct": float,
                            "output": int
                        },
                        "approx_cost": float | None,
                        "credentials": [...],
                        "global": {...}  # If include_global is True
                    }
                },
                "summary": {...},
                "global_summary": {...},  # If include_global is True
                "timestamp": float
            }
        """
        await self._lazy_init()
        now_ts = time.time()
        providers: Dict[str, Dict[str, Any]] = {}
        # Track global stats separately
        global_providers: Dict[str, Dict[str, Any]] = {}
        async with self._data_lock:
            if not self._usage_data:
                # Nothing recorded yet: return an all-zero skeleton so the
                # endpoint shape stays stable.
                return {
                    "providers": {},
                    "summary": {
                        "total_providers": 0,
                        "total_credentials": 0,
                        "active_credentials": 0,
                        "exhausted_credentials": 0,
                        "total_requests": 0,
                        "tokens": {
                            "input_cached": 0,
                            "input_uncached": 0,
                            "input_cache_pct": 0,
                            "output": 0,
                        },
                        "approx_total_cost": 0.0,
                    },
                    "global_summary": {
                        "total_providers": 0,
                        "total_credentials": 0,
                        "total_requests": 0,
                        "tokens": {
                            "input_cached": 0,
                            "input_uncached": 0,
                            "input_cache_pct": 0,
                            "output": 0,
}, "approx_total_cost": 0.0, }, "data_source": "cache", "timestamp": now_ts, } for credential, cred_data in self._usage_data.items(): # Extract provider from credential path provider = self._get_provider_from_credential(credential) if not provider: continue # Apply filter if specified if provider_filter and provider != provider_filter: continue # Initialize provider entry if provider not in providers: providers[provider] = { "credential_count": 0, "active_count": 0, "on_cooldown_count": 0, "exhausted_count": 0, "total_requests": 0, "tokens": { "input_cached": 0, "input_uncached": 0, "input_cache_pct": 0, "output": 0, }, "approx_cost": 0.0, "credentials": [], } global_providers[provider] = { "total_requests": 0, "tokens": { "input_cached": 0, "input_uncached": 0, "input_cache_pct": 0, "output": 0, }, "approx_cost": 0.0, } prov_stats = providers[provider] prov_stats["credential_count"] += 1 # Determine credential status and cooldowns key_cooldown = cred_data.get("key_cooldown_until", 0) or 0 model_cooldowns = cred_data.get("model_cooldowns", {}) # Build active cooldowns with remaining time active_cooldowns = {} for model, cooldown_ts in model_cooldowns.items(): if cooldown_ts > now_ts: remaining_seconds = int(cooldown_ts - now_ts) active_cooldowns[model] = { "until_ts": cooldown_ts, "remaining_seconds": remaining_seconds, } key_cooldown_remaining = None if key_cooldown > now_ts: key_cooldown_remaining = int(key_cooldown - now_ts) has_active_cooldown = key_cooldown > now_ts or len(active_cooldowns) > 0 # Check if exhausted (all quota groups exhausted for Antigravity) is_exhausted = False models_data = cred_data.get("models", {}) if models_data: # Check if any model has remaining quota all_exhausted = True for model_stats in models_data.values(): if isinstance(model_stats, dict): baseline = model_stats.get("baseline_remaining_fraction") if baseline is None or baseline > 0: all_exhausted = False break if all_exhausted and len(models_data) > 0: is_exhausted = True if 
is_exhausted: prov_stats["exhausted_count"] += 1 status = "exhausted" elif has_active_cooldown: prov_stats["on_cooldown_count"] += 1 status = "cooldown" else: prov_stats["active_count"] += 1 status = "active" # Aggregate token stats (current period) cred_tokens = { "input_cached": 0, "input_uncached": 0, "output": 0, } cred_requests = 0 cred_cost = 0.0 # Aggregate global token stats cred_global_tokens = { "input_cached": 0, "input_uncached": 0, "output": 0, } cred_global_requests = 0 cred_global_cost = 0.0 # Handle per-model structure (current period) if models_data: for model_name, model_stats in models_data.items(): if not isinstance(model_stats, dict): continue # Prefer request_count if available and non-zero, else fall back to success+failure req_count = model_stats.get("request_count", 0) if req_count > 0: cred_requests += req_count else: cred_requests += model_stats.get("success_count", 0) cred_requests += model_stats.get("failure_count", 0) # Token stats - track cached separately cred_tokens["input_cached"] += model_stats.get( "prompt_tokens_cached", 0 ) cred_tokens["input_uncached"] += model_stats.get( "prompt_tokens", 0 ) cred_tokens["output"] += model_stats.get("completion_tokens", 0) cred_cost += model_stats.get("approx_cost", 0.0) # Handle legacy daily structure daily_data = cred_data.get("daily", {}) daily_models = daily_data.get("models", {}) for model_name, model_stats in daily_models.items(): if not isinstance(model_stats, dict): continue cred_requests += model_stats.get("success_count", 0) cred_tokens["input_cached"] += model_stats.get( "prompt_tokens_cached", 0 ) cred_tokens["input_uncached"] += model_stats.get("prompt_tokens", 0) cred_tokens["output"] += model_stats.get("completion_tokens", 0) cred_cost += model_stats.get("approx_cost", 0.0) # Handle global stats global_data = cred_data.get("global", {}) global_models = global_data.get("models", {}) for model_name, model_stats in global_models.items(): if not isinstance(model_stats, dict): 
continue cred_global_requests += model_stats.get("success_count", 0) cred_global_tokens["input_cached"] += model_stats.get( "prompt_tokens_cached", 0 ) cred_global_tokens["input_uncached"] += model_stats.get( "prompt_tokens", 0 ) cred_global_tokens["output"] += model_stats.get( "completion_tokens", 0 ) cred_global_cost += model_stats.get("approx_cost", 0.0) # Add current period stats to global totals cred_global_requests += cred_requests cred_global_tokens["input_cached"] += cred_tokens["input_cached"] cred_global_tokens["input_uncached"] += cred_tokens["input_uncached"] cred_global_tokens["output"] += cred_tokens["output"] cred_global_cost += cred_cost # Build credential entry # Mask credential identifier for display if credential.startswith("env://"): identifier = credential else: identifier = Path(credential).name cred_entry = { "identifier": identifier, "full_path": credential, "status": status, "last_used_ts": cred_data.get("last_used_ts"), "requests": cred_requests, "tokens": cred_tokens, "approx_cost": cred_cost if cred_cost > 0 else None, } # Add cooldown info if key_cooldown_remaining is not None: cred_entry["key_cooldown_remaining"] = key_cooldown_remaining if active_cooldowns: cred_entry["model_cooldowns"] = active_cooldowns # Add global stats for this credential if include_global: # Calculate global cache percentage global_total_input = ( cred_global_tokens["input_cached"] + cred_global_tokens["input_uncached"] ) global_cache_pct = ( round( cred_global_tokens["input_cached"] / global_total_input * 100, 1, ) if global_total_input > 0 else 0 ) cred_entry["global"] = { "requests": cred_global_requests, "tokens": { "input_cached": cred_global_tokens["input_cached"], "input_uncached": cred_global_tokens["input_uncached"], "input_cache_pct": global_cache_pct, "output": cred_global_tokens["output"], }, "approx_cost": cred_global_cost if cred_global_cost > 0 else None, } # Add model-specific data for providers with per-model tracking if models_data: 
cred_entry["models"] = {} for model_name, model_stats in models_data.items(): if not isinstance(model_stats, dict): continue cred_entry["models"][model_name] = { "requests": model_stats.get("success_count", 0) + model_stats.get("failure_count", 0), "request_count": model_stats.get("request_count", 0), "success_count": model_stats.get("success_count", 0), "failure_count": model_stats.get("failure_count", 0), "prompt_tokens": model_stats.get("prompt_tokens", 0), "prompt_tokens_cached": model_stats.get( "prompt_tokens_cached", 0 ), "completion_tokens": model_stats.get( "completion_tokens", 0 ), "approx_cost": model_stats.get("approx_cost", 0.0), "window_start_ts": model_stats.get("window_start_ts"), "quota_reset_ts": model_stats.get("quota_reset_ts"), # Quota baseline fields (Antigravity-specific) "baseline_remaining_fraction": model_stats.get( "baseline_remaining_fraction" ), "baseline_fetched_at": model_stats.get( "baseline_fetched_at" ), "quota_max_requests": model_stats.get("quota_max_requests"), "quota_display": model_stats.get("quota_display"), } prov_stats["credentials"].append(cred_entry) # Aggregate to provider totals (current period) prov_stats["total_requests"] += cred_requests prov_stats["tokens"]["input_cached"] += cred_tokens["input_cached"] prov_stats["tokens"]["input_uncached"] += cred_tokens["input_uncached"] prov_stats["tokens"]["output"] += cred_tokens["output"] if cred_cost > 0: prov_stats["approx_cost"] += cred_cost # Aggregate to global provider totals global_providers[provider]["total_requests"] += cred_global_requests global_providers[provider]["tokens"]["input_cached"] += ( cred_global_tokens["input_cached"] ) global_providers[provider]["tokens"]["input_uncached"] += ( cred_global_tokens["input_uncached"] ) global_providers[provider]["tokens"]["output"] += cred_global_tokens[ "output" ] global_providers[provider]["approx_cost"] += cred_global_cost # Calculate cache percentages for each provider for provider, prov_stats in providers.items(): 
total_input = ( prov_stats["tokens"]["input_cached"] + prov_stats["tokens"]["input_uncached"] ) if total_input > 0: prov_stats["tokens"]["input_cache_pct"] = round( prov_stats["tokens"]["input_cached"] / total_input * 100, 1 ) # Set cost to None if 0 if prov_stats["approx_cost"] == 0: prov_stats["approx_cost"] = None # Calculate global cache percentages if include_global and provider in global_providers: gp = global_providers[provider] global_total = ( gp["tokens"]["input_cached"] + gp["tokens"]["input_uncached"] ) if global_total > 0: gp["tokens"]["input_cache_pct"] = round( gp["tokens"]["input_cached"] / global_total * 100, 1 ) if gp["approx_cost"] == 0: gp["approx_cost"] = None prov_stats["global"] = gp # Build summary (current period) total_creds = sum(p["credential_count"] for p in providers.values()) active_creds = sum(p["active_count"] for p in providers.values()) exhausted_creds = sum(p["exhausted_count"] for p in providers.values()) total_requests = sum(p["total_requests"] for p in providers.values()) total_input_cached = sum( p["tokens"]["input_cached"] for p in providers.values() ) total_input_uncached = sum( p["tokens"]["input_uncached"] for p in providers.values() ) total_output = sum(p["tokens"]["output"] for p in providers.values()) total_cost = sum(p["approx_cost"] or 0 for p in providers.values()) total_input = total_input_cached + total_input_uncached input_cache_pct = ( round(total_input_cached / total_input * 100, 1) if total_input > 0 else 0 ) result = { "providers": providers, "summary": { "total_providers": len(providers), "total_credentials": total_creds, "active_credentials": active_creds, "exhausted_credentials": exhausted_creds, "total_requests": total_requests, "tokens": { "input_cached": total_input_cached, "input_uncached": total_input_uncached, "input_cache_pct": input_cache_pct, "output": total_output, }, "approx_total_cost": total_cost if total_cost > 0 else None, }, "data_source": "cache", "timestamp": now_ts, } # Build global 
summary if include_global: global_total_requests = sum( gp["total_requests"] for gp in global_providers.values() ) global_total_input_cached = sum( gp["tokens"]["input_cached"] for gp in global_providers.values() ) global_total_input_uncached = sum( gp["tokens"]["input_uncached"] for gp in global_providers.values() ) global_total_output = sum( gp["tokens"]["output"] for gp in global_providers.values() ) global_total_cost = sum( gp["approx_cost"] or 0 for gp in global_providers.values() ) global_total_input = global_total_input_cached + global_total_input_uncached global_input_cache_pct = ( round(global_total_input_cached / global_total_input * 100, 1) if global_total_input > 0 else 0 ) result["global_summary"] = { "total_providers": len(global_providers), "total_credentials": total_creds, "total_requests": global_total_requests, "tokens": { "input_cached": global_total_input_cached, "input_uncached": global_total_input_uncached, "input_cache_pct": global_input_cache_pct, "output": global_total_output, }, "approx_total_cost": global_total_cost if global_total_cost > 0 else None, } return result async def reload_from_disk(self) -> None: """ Force reload usage data from disk. Useful when another process may have updated the file. """ async with self._init_lock: self._initialized.clear() await self._load_usage() await self._reset_daily_stats_if_needed() self._initialized.set()