# SPDX-License-Identifier: LGPL-3.0-only # Copyright (c) 2026 Mirrowel import json import os import time import logging import asyncio import random from datetime import date, datetime, timezone, time as dt_time from pathlib import Path from typing import Any, Dict, List, Optional, Set, Tuple, Union import aiofiles import litellm from .error_handler import ClassifiedError, NoAvailableKeysError, mask_credential from .providers import PROVIDER_PLUGINS from .utils.resilient_io import ResilientStateWriter from .utils.paths import get_data_file from .config import ( DEFAULT_FAIR_CYCLE_DURATION, DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD, DEFAULT_CUSTOM_CAP_COOLDOWN_MODE, DEFAULT_CUSTOM_CAP_COOLDOWN_VALUE, COOLDOWN_BACKOFF_TIERS, COOLDOWN_BACKOFF_MAX, COOLDOWN_AUTH_ERROR, COOLDOWN_TRANSIENT_ERROR, COOLDOWN_RATE_LIMIT_DEFAULT, ) lib_logger = logging.getLogger("rotator_library") lib_logger.propagate = False if not lib_logger.handlers: lib_logger.addHandler(logging.NullHandler()) class UsageManager: """ Manages usage statistics and cooldowns for API keys with asyncio-safe locking, asynchronous file I/O, lazy-loading mechanism, and weighted random credential rotation. The credential rotation strategy can be configured via the `rotation_tolerance` parameter: - **tolerance = 0.0**: Deterministic least-used selection. The credential with the lowest usage count is always selected. This provides predictable, perfectly balanced load distribution but may be vulnerable to fingerprinting. - **tolerance = 2.0 - 4.0 (default, recommended)**: Balanced weighted randomness. Credentials are selected randomly with weights biased toward less-used ones. Credentials within 2 uses of the maximum can still be selected with reasonable probability. This provides security through unpredictability while maintaining good load balance. - **tolerance = 5.0+**: High randomness. Even heavily-used credentials have significant selection probability. 
Useful for stress testing or maximum unpredictability, but may result in less balanced load distribution. The weight formula is: `weight = (max_usage - credential_usage) + tolerance + 1` This ensures lower-usage credentials are preferred while tolerance controls how much randomness is introduced into the selection process. Additionally, providers can specify a rotation mode: - "balanced" (default): Rotate credentials to distribute load evenly - "sequential": Use one credential until exhausted (preserves caching) """ def __init__( self, file_path: Optional[Union[str, Path]] = None, daily_reset_time_utc: Optional[str] = "03:00", rotation_tolerance: float = 0.0, provider_rotation_modes: Optional[Dict[str, str]] = None, provider_plugins: Optional[Dict[str, Any]] = None, priority_multipliers: Optional[Dict[str, Dict[int, int]]] = None, priority_multipliers_by_mode: Optional[ Dict[str, Dict[str, Dict[int, int]]] ] = None, sequential_fallback_multipliers: Optional[Dict[str, int]] = None, fair_cycle_enabled: Optional[Dict[str, bool]] = None, fair_cycle_tracking_mode: Optional[Dict[str, str]] = None, fair_cycle_cross_tier: Optional[Dict[str, bool]] = None, fair_cycle_duration: Optional[Dict[str, int]] = None, exhaustion_cooldown_threshold: Optional[Dict[str, int]] = None, custom_caps: Optional[ Dict[str, Dict[Union[int, Tuple[int, ...], str], Dict[str, Dict[str, Any]]]] ] = None, ): """ Initialize the UsageManager. Args: file_path: Path to the usage data JSON file. If None, uses get_data_file("key_usage.json"). Can be absolute Path, relative Path, or string. daily_reset_time_utc: Time in UTC when daily stats should reset (HH:MM format) rotation_tolerance: Tolerance for weighted random credential rotation. 
- 0.0: Deterministic, least-used credential always selected - tolerance = 2.0 - 4.0 (default, recommended): Balanced randomness, can pick credentials within 2 uses of max - 5.0+: High randomness, more unpredictable selection patterns provider_rotation_modes: Dict mapping provider names to rotation modes. - "balanced": Rotate credentials to distribute load evenly (default) - "sequential": Use one credential until exhausted (preserves caching) provider_plugins: Dict mapping provider names to provider plugin instances. Used for per-provider usage reset configuration (window durations, field names). priority_multipliers: Dict mapping provider -> priority -> multiplier. Universal multipliers that apply regardless of rotation mode. Example: {"antigravity": {1: 5, 2: 3}} priority_multipliers_by_mode: Dict mapping provider -> mode -> priority -> multiplier. Mode-specific overrides. Example: {"antigravity": {"balanced": {3: 1}}} sequential_fallback_multipliers: Dict mapping provider -> fallback multiplier. Used in sequential mode when priority not in priority_multipliers. Example: {"antigravity": 2} fair_cycle_enabled: Dict mapping provider -> bool to enable fair cycle rotation. When enabled, credentials must all exhaust before any can be reused. Default: enabled for sequential mode only. fair_cycle_tracking_mode: Dict mapping provider -> tracking mode. - "model_group": Track per quota group or model (default) - "credential": Track per credential globally fair_cycle_cross_tier: Dict mapping provider -> bool for cross-tier tracking. - False: Each tier cycles independently (default) - True: All credentials must exhaust regardless of tier fair_cycle_duration: Dict mapping provider -> cycle duration in seconds. Default: 86400 (24 hours) exhaustion_cooldown_threshold: Dict mapping provider -> threshold in seconds. A cooldown must exceed this to qualify as "exhausted". Default: 300 (5 min) custom_caps: Dict mapping provider -> tier -> model/group -> cap config. 
Allows setting custom usage limits per tier, per model or quota group. See ProviderInterface.default_custom_caps for format details. """ # Resolve file_path - use default if not provided if file_path is None: self.file_path = str(get_data_file("key_usage.json")) elif isinstance(file_path, Path): self.file_path = str(file_path) else: # String path - could be relative or absolute self.file_path = file_path self.rotation_tolerance = rotation_tolerance self.provider_rotation_modes = provider_rotation_modes or {} self.provider_plugins = provider_plugins or PROVIDER_PLUGINS self.priority_multipliers = priority_multipliers or {} self.priority_multipliers_by_mode = priority_multipliers_by_mode or {} self.sequential_fallback_multipliers = sequential_fallback_multipliers or {} self._provider_instances: Dict[str, Any] = {} # Cache for provider instances self.key_states: Dict[str, Dict[str, Any]] = {} # Fair cycle rotation configuration self.fair_cycle_enabled = fair_cycle_enabled or {} self.fair_cycle_tracking_mode = fair_cycle_tracking_mode or {} self.fair_cycle_cross_tier = fair_cycle_cross_tier or {} self.fair_cycle_duration = fair_cycle_duration or {} self.exhaustion_cooldown_threshold = exhaustion_cooldown_threshold or {} self.custom_caps = custom_caps or {} # In-memory cycle state: {provider: {tier_key: {tracking_key: {"cycle_started_at": float, "exhausted": Set[str]}}}} self._cycle_exhausted: Dict[str, Dict[str, Dict[str, Dict[str, Any]]]] = {} self._data_lock = asyncio.Lock() self._usage_data: Optional[Dict] = None self._initialized = asyncio.Event() self._init_lock = asyncio.Lock() self._timeout_lock = asyncio.Lock() self._claimed_on_timeout: Set[str] = set() # Resilient writer for usage data persistence self._state_writer = ResilientStateWriter(file_path, lib_logger) if daily_reset_time_utc: hour, minute = map(int, daily_reset_time_utc.split(":")) self.daily_reset_time_utc = dt_time( hour=hour, minute=minute, tzinfo=timezone.utc ) else: self.daily_reset_time_utc = 
None def _get_rotation_mode(self, provider: str) -> str: """ Get the rotation mode for a provider. Args: provider: Provider name (e.g., "antigravity", "gemini_cli") Returns: "balanced" or "sequential" """ return self.provider_rotation_modes.get(provider, "balanced") # ========================================================================= # FAIR CYCLE ROTATION HELPERS # ========================================================================= def _is_fair_cycle_enabled(self, provider: str, rotation_mode: str) -> bool: """ Check if fair cycle rotation is enabled for a provider. Args: provider: Provider name rotation_mode: Current rotation mode ("balanced" or "sequential") Returns: True if fair cycle is enabled """ # Check provider-specific setting first if provider in self.fair_cycle_enabled: return self.fair_cycle_enabled[provider] # Default: enabled only for sequential mode return rotation_mode == "sequential" def _get_fair_cycle_tracking_mode(self, provider: str) -> str: """ Get fair cycle tracking mode for a provider. Returns: "model_group" or "credential" """ return self.fair_cycle_tracking_mode.get(provider, "model_group") def _is_fair_cycle_cross_tier(self, provider: str) -> bool: """ Check if fair cycle tracks across all tiers (ignoring priority boundaries). Returns: True if cross-tier tracking is enabled """ return self.fair_cycle_cross_tier.get(provider, False) def _get_fair_cycle_duration(self, provider: str) -> int: """ Get fair cycle duration in seconds for a provider. Returns: Duration in seconds (default 86400 = 24 hours) """ return self.fair_cycle_duration.get(provider, DEFAULT_FAIR_CYCLE_DURATION) def _get_exhaustion_cooldown_threshold(self, provider: str) -> int: """ Get exhaustion cooldown threshold in seconds for a provider. A cooldown must exceed this duration to qualify as "exhausted" for fair cycle. 
Returns: Threshold in seconds (default 300 = 5 minutes) """ return self.exhaustion_cooldown_threshold.get( provider, DEFAULT_EXHAUSTION_COOLDOWN_THRESHOLD ) # ========================================================================= # CUSTOM CAPS HELPERS # ========================================================================= def _get_custom_cap_config( self, provider: str, tier_priority: int, model: str, ) -> Optional[Dict[str, Any]]: """ Get custom cap config for a provider/tier/model combination. Resolution order: 1. tier + model (exact match) 2. tier + group (model's quota group) 3. "default" + model 4. "default" + group Args: provider: Provider name tier_priority: Credential's priority level model: Model name (with provider prefix) Returns: Cap config dict or None if no custom cap applies """ provider_caps = self.custom_caps.get(provider) if not provider_caps: return None # Strip provider prefix from model clean_model = model.split("/")[-1] if "/" in model else model # Get quota group for this model group = self._get_model_quota_group_by_provider(provider, model) # Try to find matching tier config tier_config = None default_config = None for tier_key, models_config in provider_caps.items(): if tier_key == "default": default_config = models_config continue # Check if this tier_key matches our priority if isinstance(tier_key, int) and tier_key == tier_priority: tier_config = models_config break elif isinstance(tier_key, tuple) and tier_priority in tier_key: tier_config = models_config break # Resolution order for tier config if tier_config: # Try model first if clean_model in tier_config: return tier_config[clean_model] # Try group if group and group in tier_config: return tier_config[group] # Resolution order for default config if default_config: # Try model first if clean_model in default_config: return default_config[clean_model] # Try group if group and group in default_config: return default_config[group] return None def 
_get_model_quota_group_by_provider( self, provider: str, model: str ) -> Optional[str]: """ Get quota group for a model using provider name instead of credential. Args: provider: Provider name model: Model name Returns: Group name or None """ plugin_instance = self._get_provider_instance(provider) if plugin_instance and hasattr(plugin_instance, "get_model_quota_group"): return plugin_instance.get_model_quota_group(model) return None def _resolve_custom_cap_max( self, provider: str, model: str, cap_config: Dict[str, Any], actual_max: Optional[int], ) -> Optional[int]: """ Resolve custom cap max_requests value, handling percentages and clamping. Args: provider: Provider name model: Model name (for logging) cap_config: Custom cap configuration actual_max: Actual API max requests (may be None if unknown) Returns: Resolved cap value (clamped), or None if can't be calculated """ max_requests = cap_config.get("max_requests") if max_requests is None: return None # Handle percentage if isinstance(max_requests, str) and max_requests.endswith("%"): if actual_max is None: lib_logger.warning( f"Custom cap '{max_requests}' for {provider}/{model} requires known max_requests. " f"Skipping until quota baseline is fetched. Use absolute value for immediate enforcement." 
) return None try: percentage = float(max_requests.rstrip("%")) / 100.0 calculated = int(actual_max * percentage) except ValueError: lib_logger.warning( f"Invalid percentage cap '{max_requests}' for {provider}/{model}" ) return None else: # Absolute value try: calculated = int(max_requests) except (ValueError, TypeError): lib_logger.warning( f"Invalid cap value '{max_requests}' for {provider}/{model}" ) return None # Clamp to actual max (can only be MORE restrictive) if actual_max is not None: return min(calculated, actual_max) return calculated def _calculate_custom_cooldown_until( self, cap_config: Dict[str, Any], window_start_ts: Optional[float], natural_reset_ts: Optional[float], ) -> Optional[float]: """ Calculate when custom cap cooldown should end, clamped to natural reset. Args: cap_config: Custom cap configuration window_start_ts: When first request was made (for fixed mode) natural_reset_ts: Natural quota reset timestamp Returns: Cooldown end timestamp (clamped), or None if can't calculate """ mode = cap_config.get("cooldown_mode", DEFAULT_CUSTOM_CAP_COOLDOWN_MODE) value = cap_config.get("cooldown_value", DEFAULT_CUSTOM_CAP_COOLDOWN_VALUE) if mode == "quota_reset": calculated = natural_reset_ts elif mode == "offset": if natural_reset_ts is None: return None calculated = natural_reset_ts + value elif mode == "fixed": if window_start_ts is None: return None calculated = window_start_ts + value else: lib_logger.warning(f"Unknown cooldown_mode '{mode}', using quota_reset") calculated = natural_reset_ts if calculated is None: return None # Clamp to natural reset (can only be MORE restrictive = longer cooldown) if natural_reset_ts is not None: return max(calculated, natural_reset_ts) return calculated def _check_and_apply_custom_cap( self, credential: str, model: str, request_count: int, ) -> bool: """ Check if custom cap is exceeded and apply cooldown if so. This should be called after incrementing request_count in record_success(). 
Args: credential: Credential identifier model: Model name (with provider prefix) request_count: Current request count for this model Returns: True if cap exceeded and cooldown applied, False otherwise """ provider = self._get_provider_from_credential(credential) if not provider: return False priority = self._get_credential_priority(credential, provider) cap_config = self._get_custom_cap_config(provider, priority, model) if not cap_config: return False # Get model data for actual max and timing info key_data = self._usage_data.get(credential, {}) model_data = key_data.get("models", {}).get(model, {}) actual_max = model_data.get("quota_max_requests") window_start_ts = model_data.get("window_start_ts") natural_reset_ts = model_data.get("quota_reset_ts") # Resolve custom cap max custom_max = self._resolve_custom_cap_max( provider, model, cap_config, actual_max ) if custom_max is None: return False # Check if exceeded if request_count < custom_max: return False # Calculate cooldown end time cooldown_until = self._calculate_custom_cooldown_until( cap_config, window_start_ts, natural_reset_ts ) if cooldown_until is None: # Can't calculate cooldown, use natural reset if available if natural_reset_ts: cooldown_until = natural_reset_ts else: lib_logger.warning( f"Custom cap hit for {mask_credential(credential)}/{model} but can't calculate cooldown. " f"Skipping cooldown application." ) return False now_ts = time.time() # Apply cooldown model_cooldowns = key_data.setdefault("model_cooldowns", {}) model_cooldowns[model] = cooldown_until # Store custom cap info in model data for reference model_data["custom_cap_max"] = custom_max model_data["custom_cap_hit_at"] = now_ts model_data["custom_cap_cooldown_until"] = cooldown_until hours_until = (cooldown_until - now_ts) / 3600 lib_logger.info( f"Custom cap hit: {mask_credential(credential)} reached {request_count}/{custom_max} " f"for {model}. 
Cooldown for {hours_until:.1f}h" ) # Sync cooldown across quota group group = self._get_model_quota_group(credential, model) if group: grouped_models = self._get_grouped_models(credential, group) for grouped_model in grouped_models: if grouped_model != model: model_cooldowns[grouped_model] = cooldown_until # Check if this should trigger fair cycle exhaustion cooldown_duration = cooldown_until - now_ts threshold = self._get_exhaustion_cooldown_threshold(provider) if cooldown_duration > threshold: rotation_mode = self._get_rotation_mode(provider) if self._is_fair_cycle_enabled(provider, rotation_mode): tier_key = self._get_tier_key(provider, priority) tracking_key = self._get_tracking_key(credential, model, provider) self._mark_credential_exhausted( credential, provider, tier_key, tracking_key ) return True def _get_tier_key(self, provider: str, priority: int) -> str: """ Get the tier key for cycle tracking based on cross_tier setting. Args: provider: Provider name priority: Credential priority level Returns: "__all_tiers__" if cross-tier enabled, else str(priority) """ if self._is_fair_cycle_cross_tier(provider): return "__all_tiers__" return str(priority) def _get_tracking_key(self, credential: str, model: str, provider: str) -> str: """ Get the key for exhaustion tracking based on tracking mode. Args: credential: Credential identifier model: Model name (with provider prefix) provider: Provider name Returns: Tracking key string (quota group name, model name, or "__credential__") """ mode = self._get_fair_cycle_tracking_mode(provider) if mode == "credential": return "__credential__" # model_group mode: use quota group if exists, else model group = self._get_model_quota_group(credential, model) return group if group else model def _get_credential_priority(self, credential: str, provider: str) -> int: """ Get the priority level for a credential. 
Args: credential: Credential identifier provider: Provider name Returns: Priority level (default 999 if unknown) """ plugin_instance = self._get_provider_instance(provider) if plugin_instance and hasattr(plugin_instance, "get_credential_priority"): priority = plugin_instance.get_credential_priority(credential) if priority is not None: return priority return 999 def _get_cycle_data( self, provider: str, tier_key: str, tracking_key: str ) -> Optional[Dict[str, Any]]: """ Get cycle data for a provider/tier/tracking key combination. Returns: Cycle data dict or None if not exists """ return ( self._cycle_exhausted.get(provider, {}).get(tier_key, {}).get(tracking_key) ) def _ensure_cycle_structure( self, provider: str, tier_key: str, tracking_key: str ) -> Dict[str, Any]: """ Ensure the nested cycle structure exists and return the cycle data dict. """ if provider not in self._cycle_exhausted: self._cycle_exhausted[provider] = {} if tier_key not in self._cycle_exhausted[provider]: self._cycle_exhausted[provider][tier_key] = {} if tracking_key not in self._cycle_exhausted[provider][tier_key]: self._cycle_exhausted[provider][tier_key][tracking_key] = { "cycle_started_at": None, "exhausted": set(), } return self._cycle_exhausted[provider][tier_key][tracking_key] def _mark_credential_exhausted( self, credential: str, provider: str, tier_key: str, tracking_key: str, ) -> None: """ Mark a credential as exhausted for fair cycle tracking. Starts the cycle timer on first exhaustion. Skips if credential is already in the exhausted set (prevents duplicate logging). 
""" cycle_data = self._ensure_cycle_structure(provider, tier_key, tracking_key) # Skip if already exhausted in this cycle (prevents duplicate logging) if credential in cycle_data.get("exhausted", set()): return # Start cycle timer on first exhaustion if cycle_data["cycle_started_at"] is None: cycle_data["cycle_started_at"] = time.time() lib_logger.info( f"Fair cycle started for {provider} tier={tier_key} tracking='{tracking_key}'" ) cycle_data["exhausted"].add(credential) lib_logger.info( f"Fair cycle: marked {mask_credential(credential)} exhausted " f"for {tracking_key} ({len(cycle_data['exhausted'])} total)" ) def _is_credential_exhausted_in_cycle( self, credential: str, provider: str, tier_key: str, tracking_key: str, ) -> bool: """ Check if a credential was exhausted in the current cycle. """ cycle_data = self._get_cycle_data(provider, tier_key, tracking_key) if cycle_data is None: return False return credential in cycle_data.get("exhausted", set()) def _is_cycle_expired( self, provider: str, tier_key: str, tracking_key: str ) -> bool: """ Check if the current cycle has exceeded its duration. """ cycle_data = self._get_cycle_data(provider, tier_key, tracking_key) if cycle_data is None: return False cycle_started = cycle_data.get("cycle_started_at") if cycle_started is None: return False duration = self._get_fair_cycle_duration(provider) return time.time() >= cycle_started + duration def _should_reset_cycle( self, provider: str, tier_key: str, tracking_key: str, all_credentials_in_tier: List[str], available_not_on_cooldown: Optional[List[str]] = None, ) -> bool: """ Check if cycle should reset. Returns True if: 1. Cycle duration has expired, OR 2. No credentials remain available (after cooldown + fair cycle exclusion), OR 3. 
All credentials in the tier have been marked exhausted (fallback) """ # Check duration first if self._is_cycle_expired(provider, tier_key, tracking_key): return True cycle_data = self._get_cycle_data(provider, tier_key, tracking_key) if cycle_data is None: return False # If available credentials are provided, reset when none remain usable if available_not_on_cooldown is not None: has_available = any( not self._is_credential_exhausted_in_cycle( cred, provider, tier_key, tracking_key ) for cred in available_not_on_cooldown ) if not has_available and len(all_credentials_in_tier) > 0: return True exhausted = cycle_data.get("exhausted", set()) # All must be exhausted (and there must be at least one credential) return ( len(exhausted) >= len(all_credentials_in_tier) and len(all_credentials_in_tier) > 0 ) def _reset_cycle(self, provider: str, tier_key: str, tracking_key: str) -> None: """ Reset exhaustion tracking for a completed cycle. """ cycle_data = self._get_cycle_data(provider, tier_key, tracking_key) if cycle_data: exhausted_count = len(cycle_data.get("exhausted", set())) lib_logger.info( f"Fair cycle complete for {provider} tier={tier_key} " f"tracking='{tracking_key}' - resetting ({exhausted_count} credentials cycled)" ) cycle_data["cycle_started_at"] = None cycle_data["exhausted"] = set() def _get_all_credentials_for_tier_key( self, provider: str, tier_key: str, available_keys: List[str], credential_priorities: Optional[Dict[str, int]], ) -> List[str]: """ Get all credentials that belong to a tier key. 
Args: provider: Provider name tier_key: Either "__all_tiers__" or str(priority) available_keys: List of available credential identifiers credential_priorities: Dict mapping credentials to priorities Returns: List of credentials belonging to this tier key """ if tier_key == "__all_tiers__": # Cross-tier: all credentials for this provider return list(available_keys) else: # Within-tier: only credentials with matching priority priority = int(tier_key) if credential_priorities: return [ k for k in available_keys if credential_priorities.get(k, 999) == priority ] return list(available_keys) def _count_fair_cycle_excluded( self, provider: str, tier_key: str, tracking_key: str, candidates: List[str], ) -> int: """ Count how many candidates are excluded by fair cycle. Args: provider: Provider name tier_key: Tier key for tracking tracking_key: Model/group tracking key candidates: List of candidate credentials (not on cooldown) Returns: Number of candidates excluded by fair cycle """ count = 0 for cred in candidates: if self._is_credential_exhausted_in_cycle( cred, provider, tier_key, tracking_key ): count += 1 return count def _get_priority_multiplier( self, provider: str, priority: int, rotation_mode: str ) -> int: """ Get the concurrency multiplier for a provider/priority/mode combination. Lookup order: 1. Mode-specific tier override: priority_multipliers_by_mode[provider][mode][priority] 2. Universal tier multiplier: priority_multipliers[provider][priority] 3. Sequential fallback (if mode is sequential): sequential_fallback_multipliers[provider] 4. Global default: 1 (no multiplier effect) Args: provider: Provider name (e.g., "antigravity") priority: Priority level (1 = highest priority) rotation_mode: Current rotation mode ("sequential" or "balanced") Returns: Multiplier value """ provider_lower = provider.lower() # 1. 
Check mode-specific override if provider_lower in self.priority_multipliers_by_mode: mode_multipliers = self.priority_multipliers_by_mode[provider_lower] if rotation_mode in mode_multipliers: if priority in mode_multipliers[rotation_mode]: return mode_multipliers[rotation_mode][priority] # 2. Check universal tier multiplier if provider_lower in self.priority_multipliers: if priority in self.priority_multipliers[provider_lower]: return self.priority_multipliers[provider_lower][priority] # 3. Sequential fallback (only for sequential mode) if rotation_mode == "sequential": if provider_lower in self.sequential_fallback_multipliers: return self.sequential_fallback_multipliers[provider_lower] # 4. Global default return 1 def _get_provider_from_credential(self, credential: str) -> Optional[str]: """ Extract provider name from credential path or identifier. Supports multiple credential formats: - OAuth: "oauth_creds/antigravity_oauth_15.json" -> "antigravity" - OAuth: "C:\\...\\oauth_creds\\gemini_cli_oauth_1.json" -> "gemini_cli" - OAuth filename only: "antigravity_oauth_1.json" -> "antigravity" - API key style: extracted from model names in usage data (e.g., "firmware/model" -> "firmware") Args: credential: The credential identifier (path or key) Returns: Provider name string or None if cannot be determined """ import re # Pattern: env:// URI format (e.g., "env://antigravity/1" -> "antigravity") if credential.startswith("env://"): parts = credential[6:].split("/") # Remove "env://" prefix if parts and parts[0]: return parts[0].lower() # Malformed env:// URI (empty provider name) lib_logger.warning(f"Malformed env:// credential URI: {credential}") return None # Normalize path separators normalized = credential.replace("\\", "/") # Pattern: path ending with {provider}_oauth_{number}.json match = re.search(r"/([a-z_]+)_oauth_\d+\.json$", normalized, re.IGNORECASE) if match: return match.group(1).lower() # Pattern: oauth_creds/{provider}_... 
match = re.search(r"oauth_creds/([a-z_]+)_", normalized, re.IGNORECASE) if match: return match.group(1).lower() # Pattern: filename only {provider}_oauth_{number}.json (no path) match = re.match(r"([a-z_]+)_oauth_\d+\.json$", normalized, re.IGNORECASE) if match: return match.group(1).lower() # Pattern: API key prefixes for specific providers # These are raw API keys with recognizable prefixes api_key_prefixes = { "sk-nano-": "nanogpt", "sk-or-": "openrouter", "sk-ant-": "anthropic", } for prefix, provider in api_key_prefixes.items(): if credential.startswith(prefix): return provider # Fallback: For raw API keys, extract provider from model names in usage data # This handles providers like firmware, chutes, nanogpt that use credential-level quota if self._usage_data and credential in self._usage_data: cred_data = self._usage_data[credential] # Check "models" section first (for per_model mode and quota tracking) models_data = cred_data.get("models", {}) if models_data: # Get first model name and extract provider prefix first_model = next(iter(models_data.keys()), None) if first_model and "/" in first_model: provider = first_model.split("/")[0].lower() return provider # Fallback to "daily" section (legacy structure) daily_data = cred_data.get("daily", {}) daily_models = daily_data.get("models", {}) if daily_models: # Get first model name and extract provider prefix first_model = next(iter(daily_models.keys()), None) if first_model and "/" in first_model: provider = first_model.split("/")[0].lower() return provider return None def _get_provider_instance(self, provider: str) -> Optional[Any]: """ Get or create a provider plugin instance. 
Args:
            provider: The provider name

        Returns:
            Provider plugin instance or None
        """
        if not provider:
            return None

        plugin_class = self.provider_plugins.get(provider)
        if not plugin_class:
            return None

        # Get or create provider instance from cache
        if provider not in self._provider_instances:
            # Instantiate the plugin if it's a class, or use it directly if already an instance
            if isinstance(plugin_class, type):
                self._provider_instances[provider] = plugin_class()
            else:
                self._provider_instances[provider] = plugin_class

        # Later calls for the same provider reuse the cached object, so a plugin
        # class is instantiated at most once per entry in this cache.
        return self._provider_instances[provider]

    def _get_usage_reset_config(self, credential: str) -> Optional[Dict[str, Any]]:
        """
        Get the usage reset configuration for a credential from its provider plugin.

        The plugin hook ``get_usage_reset_config`` is optional: when the provider
        has no plugin, or the plugin lacks the hook, None is returned.

        Args:
            credential: The credential identifier

        Returns:
            Configuration dict with window_seconds, field_name, etc.
            or None to use default daily reset.
        """
        provider = self._get_provider_from_credential(credential)
        plugin_instance = self._get_provider_instance(provider)

        if plugin_instance and hasattr(plugin_instance, "get_usage_reset_config"):
            return plugin_instance.get_usage_reset_config(credential)

        return None

    def _get_reset_mode(self, credential: str) -> str:
        """
        Get the reset mode for a credential: 'credential' or 'per_model'.

        Args:
            credential: The credential identifier

        Returns:
            "per_model" or "credential" (default). The plugin's "mode" value is
            returned as-is when present.
        """
        config = self._get_usage_reset_config(credential)
        # No config, or a config without "mode", means credential-level tracking.
        return config.get("mode", "credential") if config else "credential"

    def _get_model_quota_group(self, credential: str, model: str) -> Optional[str]:
        """
        Get the quota group for a model, if the provider defines one.

        Args:
            credential: The credential identifier
            model: Model name (with or without provider prefix)

        Returns:
            Group name (e.g., "claude") or None if not grouped
        """
        provider = self._get_provider_from_credential(credential)
        plugin_instance = self._get_provider_instance(provider)

        if plugin_instance and hasattr(plugin_instance, "get_model_quota_group"):
            return plugin_instance.get_model_quota_group(model)

        return None

    def _get_grouped_models(self, credential: str, group: str) -> List[str]:
        """
        Get all model names in a quota group (with provider prefix), normalized.

        Returns only public-facing model names, deduplicated. Internal variants
        (e.g., claude-sonnet-4-5-thinking) are normalized to their public name
        (e.g., claude-sonnet-4.5).

        Args:
            credential: The credential identifier
            group: Group name (e.g., "claude")

        Returns:
            List of normalized, deduplicated model names with provider prefix
            (e.g., ["antigravity/claude-sonnet-4.5", "antigravity/claude-opus-4.5"]).
            Empty list when the plugin does not expose get_models_in_quota_group.
        """
        provider = self._get_provider_from_credential(credential)
        plugin_instance = self._get_provider_instance(provider)

        if plugin_instance and hasattr(plugin_instance, "get_models_in_quota_group"):
            models = plugin_instance.get_models_in_quota_group(group)

            # Normalize and deduplicate
            if hasattr(plugin_instance, "normalize_model_for_tracking"):
                # Several internal variants may normalize to the same public
                # name; keep the first occurrence and preserve group order.
                seen = set()
                normalized = []
                for m in models:
                    prefixed = f"{provider}/{m}"
                    norm = plugin_instance.normalize_model_for_tracking(prefixed)
                    if norm not in seen:
                        seen.add(norm)
                        normalized.append(norm)
                return normalized

            # Fallback: just add provider prefix
            return [f"{provider}/{m}" for m in models]

        return []

    def _get_model_usage_weight(self, credential: str, model: str) -> int:
        """
        Get the usage weight for a model when calculating grouped usage.

        Args:
            credential: The credential identifier
            model: Model name (with or without provider prefix)

        Returns:
            Weight multiplier (default 1 if not configured)
        """
        provider = self._get_provider_from_credential(credential)
        plugin_instance = self._get_provider_instance(provider)

        if plugin_instance and hasattr(plugin_instance, "get_model_usage_weight"):
            return plugin_instance.get_model_usage_weight(model)

        return 1

    def _normalize_model(self, credential: str, model: str) -> str:
        """
        Normalize model name using provider's mapping.

        Converts internal model names (e.g., claude-sonnet-4-5-thinking) to
        public-facing names (e.g., claude-sonnet-4.5) for consistent storage.

        Args:
            credential: The credential identifier
            model: Model name (with or without provider prefix)

        Returns:
            Normalized model name (provider prefix preserved if present).
            The model is returned unchanged when the plugin has no
            normalize_model_for_tracking hook.
        """
        provider = self._get_provider_from_credential(credential)
        plugin_instance = self._get_provider_instance(provider)

        if plugin_instance and hasattr(plugin_instance, "normalize_model_for_tracking"):
            return plugin_instance.normalize_model_for_tracking(model)

        return model

    # Providers where request_count should be used for credential selection
    # instead of success_count (because failed requests also consume quota)
    _REQUEST_COUNT_PROVIDERS = {"antigravity", "gemini_cli", "chutes", "nanogpt"}

    def _get_grouped_usage_count(self, key: str, model: str) -> int:
        """
        Get usage count for credential selection, considering quota groups.

        For providers in _REQUEST_COUNT_PROVIDERS (e.g., antigravity), uses
        request_count instead of success_count since failed requests also consume quota.

        If the model belongs to a quota group, the request_count is already synced
        across all models in the group (by record_success/record_failure), so we
        just read from the requested model directly. For other providers, the
        weighted success_count is summed across the group's models.
Args: key: Credential identifier model: Model name (with provider prefix, e.g., "antigravity/claude-sonnet-4-5") Returns: Usage count for the model (synced across group if applicable) """ # Determine usage field based on provider # Some providers (antigravity) count failed requests against quota provider = self._get_provider_from_credential(key) usage_field = ( "request_count" if provider in self._REQUEST_COUNT_PROVIDERS else "success_count" ) # For providers with synced quota groups (antigravity), request_count # is already synced across all models in the group, so just read directly. # For other providers, we still need to sum success_count across group. if provider in self._REQUEST_COUNT_PROVIDERS: # request_count is synced - just read the model's value return self._get_usage_count(key, model, usage_field) # For non-synced providers, check if model is in a quota group and sum group = self._get_model_quota_group(key, model) if group: # Get all models in the group grouped_models = self._get_grouped_models(key, group) # Sum weighted usage across all models in the group total_weighted_usage = 0 for grouped_model in grouped_models: usage = self._get_usage_count(key, grouped_model, usage_field) weight = self._get_model_usage_weight(key, grouped_model) total_weighted_usage += usage * weight return total_weighted_usage # Not grouped - return individual model usage (no weight applied) return self._get_usage_count(key, model, usage_field) def _get_quota_display(self, key: str, model: str) -> str: """ Get a formatted quota display string for logging. 
For antigravity (providers in _REQUEST_COUNT_PROVIDERS), returns: "quota: 170/250 [32%]" format For other providers, returns: "usage: 170" format (no max available) Args: key: Credential identifier model: Model name (with provider prefix) Returns: Formatted string for logging """ provider = self._get_provider_from_credential(key) if provider not in self._REQUEST_COUNT_PROVIDERS: # Non-antigravity: just show usage count usage = self._get_usage_count(key, model, "success_count") return f"usage: {usage}" # Antigravity: show quota display with remaining percentage if self._usage_data is None: return "quota: 0/? [100%]" # Normalize model name for consistent lookup (data is stored under normalized names) model = self._normalize_model(key, model) key_data = self._usage_data.get(key, {}) model_data = key_data.get("models", {}).get(model, {}) request_count = model_data.get("request_count", 0) max_requests = model_data.get("quota_max_requests") if max_requests: remaining = max_requests - request_count remaining_pct = ( int((remaining / max_requests) * 100) if max_requests > 0 else 0 ) return f"quota: {request_count}/{max_requests} [{remaining_pct}%]" else: return f"quota: {request_count}" def _get_usage_field_name(self, credential: str) -> str: """ Get the usage tracking field name for a credential. Returns the provider-specific field name if configured, otherwise falls back to "daily". 
Args:
            credential: The credential identifier

        Returns:
            Field name string (e.g., "5h_window", "weekly", "daily")
        """
        # 1. An explicit field_name in the plugin's reset config wins.
        config = self._get_usage_reset_config(credential)
        if config and "field_name" in config:
            return config["field_name"]

        # Check provider default
        # 2. Otherwise use the plugin's default field name, if it has that hook.
        provider = self._get_provider_from_credential(credential)
        plugin_instance = self._get_provider_instance(provider)
        if plugin_instance and hasattr(plugin_instance, "get_default_usage_field_name"):
            return plugin_instance.get_default_usage_field_name()

        # 3. Final fallback.
        return "daily"

    def _get_usage_count(
        self, key: str, model: str, field: str = "success_count"
    ) -> int:
        """
        Get the current usage count for a model from the appropriate usage structure.

        Supports both:
        - New per-model structure: {"models": {"model_name": {"success_count": N, ...}}}
        - Legacy structure: {"daily": {"models": {"model_name": {"success_count": N, ...}}}}

        Args:
            key: Credential identifier
            model: Model name
            field: The field to read for usage count (default: "success_count").
                   Use "request_count" for providers where failed requests also
                   consume quota (e.g., antigravity).

        Returns:
            Usage count for the model in the current window/period (0 when the
            data is not loaded or the model has no entry)
        """
        if self._usage_data is None:
            return 0

        # Normalize model name for consistent lookup (data is stored under normalized names)
        model = self._normalize_model(key, model)

        key_data = self._usage_data.get(key, {})
        reset_mode = self._get_reset_mode(key)

        if reset_mode == "per_model":
            # New per-model structure: key_data["models"][model][field]
            return key_data.get("models", {}).get(model, {}).get(field, 0)
        else:
            # Legacy structure: key_data["daily"]["models"][model][field]
            # NOTE(review): this always reads the "daily" bucket, even when the
            # plugin configures a different credential-level field_name (see
            # _get_usage_field_name above) - confirm this is intended.
            return (
                key_data.get("daily", {}).get("models", {}).get(model, {}).get(field, 0)
            )

    # =========================================================================
    # TIMESTAMP FORMATTING HELPERS
    # =========================================================================

    def _format_timestamp_local(self, ts: Optional[float]) -> Optional[str]:
        """
        Format Unix timestamp as local time string with timezone offset.

        Args:
            ts: Unix timestamp or None

        Returns:
            Formatted string like "2025-12-07 14:30:17 +0100" or None
            (also None when the timestamp cannot be converted)
        """
        if ts is None:
            return None
        try:
            dt = datetime.fromtimestamp(ts).astimezone()  # Local timezone
            # Use UTC offset for conciseness (works on all platforms)
            return dt.strftime("%Y-%m-%d %H:%M:%S %z")
        except (OSError, ValueError, OverflowError):
            # Out-of-range timestamps are reported as "no readable time".
            return None

    def _add_readable_timestamps(self, data: Dict) -> Dict:
        """
        Add human-readable timestamp fields to usage data before saving.

        Adds 'window_started' and 'quota_resets' fields derived from
        Unix timestamps for easier debugging and monitoring.
Args: data: The usage data dict to enhance Returns: The same dict with readable timestamp fields added """ for key, key_data in data.items(): # Handle per-model structure models = key_data.get("models", {}) for model_name, model_stats in models.items(): if not isinstance(model_stats, dict): continue # Add readable window start time window_start = model_stats.get("window_start_ts") if window_start: model_stats["window_started"] = self._format_timestamp_local( window_start ) elif "window_started" in model_stats: del model_stats["window_started"] # Add readable reset time quota_reset = model_stats.get("quota_reset_ts") if quota_reset: model_stats["quota_resets"] = self._format_timestamp_local( quota_reset ) elif "quota_resets" in model_stats: del model_stats["quota_resets"] return data def _sort_sequential( self, candidates: List[Tuple[str, int]], credential_priorities: Optional[Dict[str, int]] = None, ) -> List[Tuple[str, int]]: """ Sort credentials for sequential mode with position retention. Credentials maintain their position based on established usage patterns, ensuring that actively-used credentials remain primary until exhausted. Sorting order (within each sort key, lower value = higher priority): 1. Priority tier (lower number = higher priority) 2. Usage count (higher = more established in rotation, maintains position) 3. Last used timestamp (higher = more recent, tiebreaker for stickiness) 4. 
Credential ID (alphabetical, stable ordering) Args: candidates: List of (credential_id, usage_count) tuples credential_priorities: Optional dict mapping credentials to priority levels Returns: Sorted list of candidates (same format as input) """ if not candidates: return [] if len(candidates) == 1: return candidates def sort_key(item: Tuple[str, int]) -> Tuple[int, int, float, str]: cred, usage_count = item priority = ( credential_priorities.get(cred, 999) if credential_priorities else 999 ) last_used = ( self._usage_data.get(cred, {}).get("last_used_ts", 0) if self._usage_data else 0 ) return ( priority, # ASC: lower priority number = higher priority -usage_count, # DESC: higher usage = more established -last_used, # DESC: more recent = preferred for ties cred, # ASC: stable alphabetical ordering ) sorted_candidates = sorted(candidates, key=sort_key) # Debug logging - show top 3 credentials in ordering if lib_logger.isEnabledFor(logging.DEBUG): order_info = [ f"{mask_credential(c)}(p={credential_priorities.get(c, 999) if credential_priorities else 'N/A'}, u={u})" for c, u in sorted_candidates[:3] ] lib_logger.debug(f"Sequential ordering: {' → '.join(order_info)}") return sorted_candidates # ========================================================================= # FAIR CYCLE PERSISTENCE # ========================================================================= def _serialize_cycle_state(self) -> Dict[str, Any]: """ Serialize in-memory cycle state for JSON persistence. Converts sets to lists for JSON compatibility. 
""" result: Dict[str, Any] = {} for provider, tier_data in self._cycle_exhausted.items(): result[provider] = {} for tier_key, tracking_data in tier_data.items(): result[provider][tier_key] = {} for tracking_key, cycle_data in tracking_data.items(): result[provider][tier_key][tracking_key] = { "cycle_started_at": cycle_data.get("cycle_started_at"), "exhausted": list(cycle_data.get("exhausted", set())), } return result def _deserialize_cycle_state(self, data: Dict[str, Any]) -> None: """ Deserialize cycle state from JSON and populate in-memory structure. Converts lists back to sets and validates expired cycles. """ self._cycle_exhausted = {} now_ts = time.time() for provider, tier_data in data.items(): if not isinstance(tier_data, dict): continue self._cycle_exhausted[provider] = {} for tier_key, tracking_data in tier_data.items(): if not isinstance(tracking_data, dict): continue self._cycle_exhausted[provider][tier_key] = {} for tracking_key, cycle_data in tracking_data.items(): if not isinstance(cycle_data, dict): continue cycle_started = cycle_data.get("cycle_started_at") exhausted_list = cycle_data.get("exhausted", []) # Check if cycle has expired if cycle_started is not None: duration = self._get_fair_cycle_duration(provider) if now_ts >= cycle_started + duration: # Cycle expired - skip (don't restore) lib_logger.debug( f"Fair cycle expired for {provider}/{tier_key}/{tracking_key} - not restoring" ) continue # Restore valid cycle self._cycle_exhausted[provider][tier_key][tracking_key] = { "cycle_started_at": cycle_started, "exhausted": set(exhausted_list) if exhausted_list else set(), } # Log restoration summary total_cycles = sum( len(tracking) for tier in self._cycle_exhausted.values() for tracking in tier.values() ) if total_cycles > 0: lib_logger.info(f"Restored {total_cycles} active fair cycle(s) from disk") async def _lazy_init(self): """Initializes the usage data by loading it from the file asynchronously.""" async with self._init_lock: if not 
self._initialized.is_set(): await self._load_usage() await self._reset_daily_stats_if_needed() self._initialized.set() async def _load_usage(self): """Loads usage data from the JSON file asynchronously with resilience.""" async with self._data_lock: if not os.path.exists(self.file_path): self._usage_data = {} return try: async with aiofiles.open(self.file_path, "r") as f: content = await f.read() self._usage_data = json.loads(content) if content.strip() else {} except FileNotFoundError: # File deleted between exists check and open self._usage_data = {} except json.JSONDecodeError as e: lib_logger.warning( f"Corrupted usage file {self.file_path}: {e}. Starting fresh." ) self._usage_data = {} except (OSError, PermissionError, IOError) as e: lib_logger.warning( f"Cannot read usage file {self.file_path}: {e}. Using empty state." ) self._usage_data = {} # Restore fair cycle state from persisted data fair_cycle_data = self._usage_data.get("__fair_cycle__", {}) if fair_cycle_data: self._deserialize_cycle_state(fair_cycle_data) async def _save_usage(self): """Saves the current usage data using the resilient state writer.""" if self._usage_data is None: return async with self._data_lock: # Add human-readable timestamp fields before saving self._add_readable_timestamps(self._usage_data) # Persist fair cycle state (separate from credential data) if self._cycle_exhausted: self._usage_data["__fair_cycle__"] = self._serialize_cycle_state() elif "__fair_cycle__" in self._usage_data: # Clean up empty cycle data del self._usage_data["__fair_cycle__"] # Hand off to resilient writer - handles retries and disk failures self._state_writer.write(self._usage_data) async def _get_usage_data_snapshot(self) -> Dict[str, Any]: """ Get a shallow copy of the current usage data. 
Returns:
            Copy of usage data dict. Only the top-level mapping is copied;
            per-credential dicts are still shared with the live data.
        """
        await self._lazy_init()
        async with self._data_lock:
            return dict(self._usage_data) if self._usage_data else {}

    async def get_available_credentials_for_model(
        self, credentials: List[str], model: str
    ) -> List[str]:
        """
        Get credentials that are not on cooldown for a specific model.

        Filters out credentials where:
        - key_cooldown_until > now (key-level cooldown)
        - model_cooldowns[model] > now (model-specific cooldown, includes quota exhausted)

        Args:
            credentials: List of credential identifiers to check
            model: Model name to check cooldowns for

        Returns:
            List of credentials that are available (not on cooldown) for this model,
            in the same order as `credentials`
        """
        await self._lazy_init()
        now = time.time()
        available = []

        async with self._data_lock:
            for key in credentials:
                key_data = self._usage_data.get(key, {})

                # Skip if key-level cooldown is active
                # (`or 0` treats a missing/None end time as "no cooldown")
                if (key_data.get("key_cooldown_until") or 0) > now:
                    continue

                # Normalize model name for consistent cooldown lookup
                # (cooldowns are stored under normalized names by record_failure)
                # For providers without normalize_model_for_tracking (non-Antigravity),
                # this returns the model unchanged, so cooldown lookups work as before.
                normalized_model = self._normalize_model(key, model)

                # Skip if model-specific cooldown is active
                if (
                    key_data.get("model_cooldowns", {}).get(normalized_model) or 0
                ) > now:
                    continue

                available.append(key)

        return available

    async def get_credential_availability_stats(
        self,
        credentials: List[str],
        model: str,
        credential_priorities: Optional[Dict[str, int]] = None,
    ) -> Dict[str, int]:
        """
        Get credential availability statistics including cooldown and fair cycle exclusions.

        This is used for logging to show why credentials are excluded.

        Args:
            credentials: List of credential identifiers to check
            model: Model name to check
            credential_priorities: Optional dict mapping credentials to priorities

        Returns:
            Dict with:
                "total": Total credentials
                "on_cooldown": Count on cooldown
                "fair_cycle_excluded": Count excluded by fair cycle
                "available": Count available for selection
        """
        await self._lazy_init()
        now = time.time()

        total = len(credentials)
        on_cooldown = 0
        not_on_cooldown = []

        # First pass: check cooldowns
        async with self._data_lock:
            for key in credentials:
                key_data = self._usage_data.get(key, {})

                # Check if key-level or model-level cooldown is active
                normalized_model = self._normalize_model(key, model)
                if (key_data.get("key_cooldown_until") or 0) > now or (
                    key_data.get("model_cooldowns", {}).get(normalized_model) or 0
                ) > now:
                    on_cooldown += 1
                else:
                    not_on_cooldown.append(key)

        # Second pass: check fair cycle exclusions (only for non-cooldown credentials)
        # Note: this pass runs outside the data lock.
        fair_cycle_excluded = 0
        if not_on_cooldown:
            # The provider is derived from the first remaining credential only;
            # presumably all `credentials` belong to one provider - confirm with callers.
            provider = self._get_provider_from_credential(not_on_cooldown[0])
            if provider:
                rotation_mode = self._get_rotation_mode(provider)
                if self._is_fair_cycle_enabled(provider, rotation_mode):
                    # Check each credential against its own tier's exhausted set
                    for key in not_on_cooldown:
                        key_priority = (
                            credential_priorities.get(key, 999)
                            if credential_priorities
                            else 999
                        )
                        tier_key = self._get_tier_key(provider, key_priority)
                        tracking_key = self._get_tracking_key(key, model, provider)

                        if self._is_credential_exhausted_in_cycle(
                            key, provider, tier_key, tracking_key
                        ):
                            fair_cycle_excluded += 1

        available = total - on_cooldown - fair_cycle_excluded

        return {
            "total": total,
            "on_cooldown": on_cooldown,
            "fair_cycle_excluded": fair_cycle_excluded,
            "available": available,
        }

    async def get_soonest_cooldown_end(
        self,
        credentials: List[str],
        model: str,
    ) -> Optional[float]:
        """
        Find the soonest time when any credential will come off cooldown.
This is used for smart waiting logic - if no credentials are available, we can determine whether to wait (if soonest cooldown < deadline) or fail fast (if soonest cooldown > deadline). Args: credentials: List of credential identifiers to check model: Model name to check cooldowns for Returns: Timestamp of soonest cooldown end, or None if no credentials are on cooldown """ await self._lazy_init() now = time.time() soonest_end = None async with self._data_lock: for key in credentials: key_data = self._usage_data.get(key, {}) normalized_model = self._normalize_model(key, model) # Check key-level cooldown key_cooldown = key_data.get("key_cooldown_until") or 0 if key_cooldown > now: if soonest_end is None or key_cooldown < soonest_end: soonest_end = key_cooldown # Check model-level cooldown model_cooldown = ( key_data.get("model_cooldowns", {}).get(normalized_model) or 0 ) if model_cooldown > now: if soonest_end is None or model_cooldown < soonest_end: soonest_end = model_cooldown return soonest_end async def _reset_daily_stats_if_needed(self): """ Checks if usage stats need to be reset for any key. Supports three reset modes: 1. per_model: Each model has its own window, resets based on quota_reset_ts or fallback window 2. credential: One window per credential (legacy with custom window duration) 3. 
daily: Legacy daily reset at daily_reset_time_utc """ if self._usage_data is None: return now_utc = datetime.now(timezone.utc) now_ts = time.time() today_str = now_utc.date().isoformat() needs_saving = False for key, data in self._usage_data.items(): reset_config = self._get_usage_reset_config(key) if reset_config: reset_mode = reset_config.get("mode", "credential") if reset_mode == "per_model": # Per-model window reset needs_saving |= await self._check_per_model_resets( key, data, reset_config, now_ts ) else: # Credential-level window reset (legacy) needs_saving |= await self._check_window_reset( key, data, reset_config, now_ts ) elif self.daily_reset_time_utc: # Legacy daily reset needs_saving |= await self._check_daily_reset( key, data, now_utc, today_str, now_ts ) if needs_saving: await self._save_usage() async def _check_per_model_resets( self, key: str, data: Dict[str, Any], reset_config: Dict[str, Any], now_ts: float, ) -> bool: """ Check and perform per-model resets for a credential. Each model resets independently based on: 1. quota_reset_ts (authoritative, from quota exhausted error) if set 2. window_start_ts + window_seconds (fallback) otherwise Grouped models reset together - all models in a group must be ready. 
Args:
            key: Credential identifier
            data: Usage data for this credential
            reset_config: Provider's reset configuration
            now_ts: Current timestamp

        Returns:
            True if data was modified and needs saving
        """
        window_seconds = reset_config.get("window_seconds", 86400)
        models_data = data.get("models", {})

        if not models_data:
            return False

        modified = False
        # Each quota group is evaluated once, even though every member model
        # appears as its own entry in models_data.
        processed_groups = set()

        # Iterate over a snapshot of the items (list(...)) rather than the live view.
        for model, model_data in list(models_data.items()):
            # Check if this model is in a quota group
            group = self._get_model_quota_group(key, model)

            if group:
                if group in processed_groups:
                    continue  # Already handled this group

                # Check if entire group should reset
                if self._should_group_reset(
                    key, group, models_data, window_seconds, now_ts
                ):
                    # Archive and reset all models in group
                    grouped_models = self._get_grouped_models(key, group)
                    archived_count = 0

                    for grouped_model in grouped_models:
                        # Group members without stored stats are skipped.
                        if grouped_model in models_data:
                            gm_data = models_data[grouped_model]
                            self._archive_model_to_global(data, grouped_model, gm_data)
                            self._reset_model_data(gm_data)
                            archived_count += 1

                    if archived_count > 0:
                        lib_logger.info(
                            f"Reset model group '{group}' ({archived_count} models) for {mask_credential(key)}"
                        )
                        modified = True

                processed_groups.add(group)

            else:
                # Ungrouped model - check individually
                if self._should_model_reset(model_data, window_seconds, now_ts):
                    self._archive_model_to_global(data, model, model_data)
                    self._reset_model_data(model_data)
                    lib_logger.info(f"Reset model {model} for {mask_credential(key)}")
                    modified = True

        # Preserve unexpired cooldowns
        # (and clear consecutive-failure tracking) only when something was reset.
        if modified:
            self._preserve_unexpired_cooldowns(key, data, now_ts)
            if "failures" in data:
                data["failures"] = {}

        return modified

    def _should_model_reset(
        self, model_data: Dict[str, Any], window_seconds: int, now_ts: float
    ) -> bool:
        """
        Check if a single model should reset.
Returns True if: - quota_reset_ts is set AND now >= quota_reset_ts, OR - quota_reset_ts is NOT set AND now >= window_start_ts + window_seconds """ quota_reset = model_data.get("quota_reset_ts") window_start = model_data.get("window_start_ts") if quota_reset: return now_ts >= quota_reset elif window_start: return now_ts >= window_start + window_seconds return False def _should_group_reset( self, key: str, group: str, models_data: Dict[str, Dict], window_seconds: int, now_ts: float, ) -> bool: """ Check if all models in a group should reset. All models in the group must be ready to reset. If any model has an active cooldown/window, the whole group waits. """ grouped_models = self._get_grouped_models(key, group) # Track if any model in group has data any_has_data = False for grouped_model in grouped_models: model_data = models_data.get(grouped_model, {}) if not model_data or ( model_data.get("window_start_ts") is None and model_data.get("success_count", 0) == 0 ): continue # No stats for this model yet any_has_data = True if not self._should_model_reset(model_data, window_seconds, now_ts): return False # At least one model not ready return any_has_data def _archive_model_to_global( self, data: Dict[str, Any], model: str, model_data: Dict[str, Any] ) -> None: """Archive a single model's stats to global.""" global_data = data.setdefault("global", {"models": {}}) global_model = global_data["models"].setdefault( model, { "success_count": 0, "prompt_tokens": 0, "prompt_tokens_cached": 0, "completion_tokens": 0, "approx_cost": 0.0, }, ) global_model["success_count"] += model_data.get("success_count", 0) global_model["prompt_tokens"] += model_data.get("prompt_tokens", 0) global_model["prompt_tokens_cached"] = global_model.get( "prompt_tokens_cached", 0 ) + model_data.get("prompt_tokens_cached", 0) global_model["completion_tokens"] += model_data.get("completion_tokens", 0) global_model["approx_cost"] += model_data.get("approx_cost", 0.0) def _reset_model_data(self, 
model_data: Dict[str, Any]) -> None: """Reset a model's window and stats.""" model_data["window_start_ts"] = None model_data["quota_reset_ts"] = None model_data["success_count"] = 0 model_data["failure_count"] = 0 model_data["request_count"] = 0 model_data["prompt_tokens"] = 0 model_data["completion_tokens"] = 0 model_data["approx_cost"] = 0.0 # Reset quota baseline fields only if they exist (Antigravity-specific) # These are added by update_quota_baseline(), only called for Antigravity if "baseline_remaining_fraction" in model_data: model_data["baseline_remaining_fraction"] = None model_data["baseline_fetched_at"] = None model_data["requests_at_baseline"] = None # Reset quota display but keep max_requests (it doesn't change between periods) max_req = model_data.get("quota_max_requests") if max_req: model_data["quota_display"] = f"0/{max_req}" async def _check_window_reset( self, key: str, data: Dict[str, Any], reset_config: Dict[str, Any], now_ts: float, ) -> bool: """ Check and perform rolling window reset for a credential. 
Args:
            key: Credential identifier
            data: Usage data for this credential
            reset_config: Provider's reset configuration
            now_ts: Current timestamp

        Returns:
            True if data was modified and needs saving
        """
        window_seconds = reset_config.get("window_seconds", 86400)  # Default 24h
        field_name = reset_config.get("field_name", "window")
        description = reset_config.get("description", "rolling window")

        # Get current window data
        window_data = data.get(field_name, {})
        window_start = window_data.get("start_ts")

        # No window started yet - nothing to reset
        if window_start is None:
            return False

        # Check if window has expired
        window_end = window_start + window_seconds
        if now_ts < window_end:
            # Window still active
            return False

        # Window expired - perform reset
        hours_elapsed = (now_ts - window_start) / 3600
        lib_logger.info(
            f"Resetting {field_name} for {mask_credential(key)} - "
            f"{description} expired after {hours_elapsed:.1f}h"
        )

        # Archive to global
        self._archive_to_global(data, window_data)

        # Preserve unexpired cooldowns
        self._preserve_unexpired_cooldowns(key, data, now_ts)

        # Reset window stats (but don't start new window until first request)
        data[field_name] = {"start_ts": None, "models": {}}

        # Reset consecutive failures
        if "failures" in data:
            data["failures"] = {}

        return True

    async def _check_daily_reset(
        self,
        key: str,
        data: Dict[str, Any],
        now_utc: datetime,
        today_str: str,
        now_ts: float,
    ) -> bool:
        """
        Check and perform legacy daily reset for a credential.

        Args:
            key: Credential identifier
            data: Usage data for this credential
            now_utc: Current datetime in UTC
            today_str: Today's date as ISO string
            now_ts: Current timestamp

        Returns:
            True if data was modified and needs saving
        """
        last_reset_str = data.get("last_daily_reset", "")

        if last_reset_str == today_str:
            return False

        # last_daily_reset is stored as a date string, so it parses to midnight
        # UTC. An unparseable value leaves last_reset_dt as None, which forces
        # a reset below.
        last_reset_dt = None
        if last_reset_str:
            try:
                last_reset_dt = datetime.fromisoformat(last_reset_str).replace(
                    tzinfo=timezone.utc
                )
            except ValueError:
                pass

        # Determine the reset threshold for today
        # NOTE(review): this is compared with timezone-aware datetimes below, so
        # daily_reset_time_utc presumably carries tzinfo=UTC (set in __init__,
        # not visible here) - confirm, otherwise the comparison raises TypeError.
        reset_threshold_today = datetime.combine(
            now_utc.date(), self.daily_reset_time_utc
        )

        if not (
            last_reset_dt is None or last_reset_dt < reset_threshold_today <= now_utc
        ):
            return False

        lib_logger.debug(f"Performing daily reset for key {mask_credential(key)}")

        # Preserve unexpired cooldowns
        self._preserve_unexpired_cooldowns(key, data, now_ts)

        # Reset consecutive failures
        if "failures" in data:
            data["failures"] = {}

        # Archive daily stats to global
        daily_data = data.get("daily", {})
        if daily_data:
            self._archive_to_global(data, daily_data)

        # Reset daily stats
        data["daily"] = {"date": today_str, "models": {}}
        data["last_daily_reset"] = today_str

        return True

    def _archive_to_global(
        self, data: Dict[str, Any], source_data: Dict[str, Any]
    ) -> None:
        """
        Archive usage stats from a source field (daily/window) to global.
Args:
            data: The credential's usage data
            source_data: The source field data to archive (has "models" key)
        """
        global_data = data.setdefault("global", {"models": {}})
        for model, stats in source_data.get("models", {}).items():
            # Create a zeroed global entry the first time a model is archived.
            global_model_stats = global_data["models"].setdefault(
                model,
                {
                    "success_count": 0,
                    "prompt_tokens": 0,
                    "prompt_tokens_cached": 0,
                    "completion_tokens": 0,
                    "approx_cost": 0.0,
                },
            )
            global_model_stats["success_count"] += stats.get("success_count", 0)
            global_model_stats["prompt_tokens"] += stats.get("prompt_tokens", 0)
            # .get() on the global side so entries lacking this key are tolerated.
            global_model_stats["prompt_tokens_cached"] = global_model_stats.get(
                "prompt_tokens_cached", 0
            ) + stats.get("prompt_tokens_cached", 0)
            global_model_stats["completion_tokens"] += stats.get("completion_tokens", 0)
            global_model_stats["approx_cost"] += stats.get("approx_cost", 0.0)

    def _preserve_unexpired_cooldowns(
        self, key: str, data: Dict[str, Any], now_ts: float
    ) -> None:
        """
        Preserve unexpired cooldowns during reset (important for long quota cooldowns).
Args: key: Credential identifier (for logging) data: The credential's usage data now_ts: Current timestamp """ # Preserve unexpired model cooldowns if "model_cooldowns" in data: active_cooldowns = { model: end_time for model, end_time in data["model_cooldowns"].items() if end_time > now_ts } if active_cooldowns: max_remaining = max( end_time - now_ts for end_time in active_cooldowns.values() ) hours_remaining = max_remaining / 3600 lib_logger.info( f"Preserving {len(active_cooldowns)} active cooldown(s) " f"for key {mask_credential(key)} during reset " f"(longest: {hours_remaining:.1f}h remaining)" ) data["model_cooldowns"] = active_cooldowns else: data["model_cooldowns"] = {} # Preserve unexpired key-level cooldown if data.get("key_cooldown_until"): if data["key_cooldown_until"] <= now_ts: data["key_cooldown_until"] = None else: hours_remaining = (data["key_cooldown_until"] - now_ts) / 3600 lib_logger.info( f"Preserving key-level cooldown for {mask_credential(key)} " f"during reset ({hours_remaining:.1f}h remaining)" ) else: data["key_cooldown_until"] = None def _initialize_key_states(self, keys: List[str]): """Initializes state tracking for all provided keys if not already present.""" for key in keys: if key not in self.key_states: self.key_states[key] = { "lock": asyncio.Lock(), "condition": asyncio.Condition(), "models_in_use": {}, # Dict[model_name, concurrent_count] } def _select_weighted_random(self, candidates: List[tuple], tolerance: float) -> str: """ Selects a credential using weighted random selection based on usage counts. 
Args: candidates: List of (credential_id, usage_count) tuples tolerance: Tolerance value for weight calculation Returns: Selected credential ID Formula: weight = (max_usage - credential_usage) + tolerance + 1 This formula ensures: - Lower usage = higher weight = higher selection probability - Tolerance adds variability: higher tolerance means more randomness - The +1 ensures all credentials have at least some chance of selection """ if not candidates: raise ValueError("Cannot select from empty candidate list") if len(candidates) == 1: return candidates[0][0] # Extract usage counts usage_counts = [usage for _, usage in candidates] max_usage = max(usage_counts) # Calculate weights using the formula: (max - current) + tolerance + 1 weights = [] for credential, usage in candidates: weight = (max_usage - usage) + tolerance + 1 weights.append(weight) # Log weight distribution for debugging if lib_logger.isEnabledFor(logging.DEBUG): total_weight = sum(weights) weight_info = ", ".join( f"{mask_credential(cred)}: w={w:.1f} ({w / total_weight * 100:.1f}%)" for (cred, _), w in zip(candidates, weights) ) # lib_logger.debug(f"Weighted selection candidates: {weight_info}") # Random selection with weights selected_credential = random.choices( [cred for cred, _ in candidates], weights=weights, k=1 )[0] return selected_credential async def acquire_key( self, available_keys: List[str], model: str, deadline: float, max_concurrent: int = 1, credential_priorities: Optional[Dict[str, int]] = None, credential_tier_names: Optional[Dict[str, str]] = None, all_provider_credentials: Optional[List[str]] = None, ) -> str: """ Acquires the best available key using a tiered, model-aware locking strategy, respecting a global deadline and credential priorities. Priority Logic: - Groups credentials by priority level (1=highest, 2=lower, etc.) 
- Always tries highest priority (lowest number) first
        - Within same priority, sorts by usage count (load balancing)
        - Only moves to next priority if all higher-priority keys exhausted/busy

        Within a group, keys are split into tier 1 (no model in use at all) and
        tier 2 (still below the effective concurrency limit for this model);
        tier 1 is always tried first. Ordering inside each tier follows the
        provider's rotation mode: "sequential" uses _sort_sequential, otherwise
        weighted-random when rotation_tolerance > 0, else least-used.

        The in-use counter is keyed by the raw ``model`` argument (not the
        normalized name), so callers must pass the same string to release_key.

        Args:
            available_keys: List of credential identifiers to choose from
            model: Model name being requested
            deadline: Timestamp after which to stop trying
            max_concurrent: Maximum concurrent requests allowed per credential
            credential_priorities: Optional dict mapping credentials to priority levels (1=highest)
            credential_tier_names: Optional dict mapping credentials to tier names (for logging)
            all_provider_credentials: Full list of provider credentials (used for cycle reset checks)

        Returns:
            Selected credential identifier

        Raises:
            NoAvailableKeysError: If no key could be acquired within the deadline
        """
        await self._lazy_init()
        await self._reset_daily_stats_if_needed()
        self._initialize_key_states(available_keys)

        # Normalize model name for consistent cooldown lookup
        # (cooldowns are stored under normalized names by record_failure)
        # Use first credential for provider detection; all credentials passed here
        # are for the same provider (filtered by client.py before calling acquire_key).
        # For providers without normalize_model_for_tracking (non-Antigravity),
        # this returns the model unchanged, so cooldown lookups work as before.
        normalized_model = (
            self._normalize_model(available_keys[0], model) if available_keys else model
        )

        # This loop continues as long as the global deadline has not been met.
        # Each pass re-reads cooldowns and in-use counts. It exits by returning a
        # key, by `break` (fail fast / budget spent), or when the deadline passes.
        while time.time() < deadline:
            now = time.time()

            # Group credentials by priority level (if priorities provided)
            if credential_priorities:
                # Group keys by priority level
                priority_groups = {}
                async with self._data_lock:
                    for key in available_keys:
                        key_data = self._usage_data.get(key, {})

                        # Skip keys on cooldown (use normalized model for lookup)
                        if (key_data.get("key_cooldown_until") or 0) > now or (
                            key_data.get("model_cooldowns", {}).get(normalized_model)
                            or 0
                        ) > now:
                            continue

                        # Get priority for this key (default to 999 if not specified)
                        priority = credential_priorities.get(key, 999)

                        # Get usage count for load balancing within priority groups
                        # Uses grouped usage if model is in a quota group
                        usage_count = self._get_grouped_usage_count(key, model)

                        # Group by priority
                        if priority not in priority_groups:
                            priority_groups[priority] = []
                        priority_groups[priority].append((key, usage_count))

                # Try priority groups in order (1, 2, 3, ...)
                sorted_priorities = sorted(priority_groups.keys())

                for priority_level in sorted_priorities:
                    keys_in_priority = priority_groups[priority_level]

                    # Determine selection method based on provider's rotation mode
                    provider = model.split("/")[0] if "/" in model else ""
                    rotation_mode = self._get_rotation_mode(provider)

                    # Fair cycle filtering
                    if provider and self._is_fair_cycle_enabled(
                        provider, rotation_mode
                    ):
                        tier_key = self._get_tier_key(provider, priority_level)
                        tracking_key = self._get_tracking_key(
                            keys_in_priority[0][0] if keys_in_priority else "",
                            model,
                            provider,
                        )

                        # Get all credentials for this tier (for cycle completion check)
                        all_tier_creds = self._get_all_credentials_for_tier_key(
                            provider,
                            tier_key,
                            all_provider_credentials or available_keys,
                            credential_priorities,
                        )

                        # Check if cycle should reset (all exhausted, expired, or none available)
                        if self._should_reset_cycle(
                            provider,
                            tier_key,
                            tracking_key,
                            all_tier_creds,
                            available_not_on_cooldown=[
                                key for key, _ in keys_in_priority
                            ],
                        ):
                            self._reset_cycle(provider, tier_key, tracking_key)

                        # Filter out exhausted credentials
                        # (rebinds keys_in_priority only; priority_groups keeps the
                        # unfiltered list, which is what the wait logic below uses)
                        filtered_keys = []
                        for key, usage_count in keys_in_priority:
                            if not self._is_credential_exhausted_in_cycle(
                                key, provider, tier_key, tracking_key
                            ):
                                filtered_keys.append((key, usage_count))
                        keys_in_priority = filtered_keys

                    # Calculate effective concurrency based on priority tier
                    multiplier = self._get_priority_multiplier(
                        provider, priority_level, rotation_mode
                    )
                    effective_max_concurrent = max_concurrent * multiplier

                    # Within each priority group, use existing tier1/tier2 logic.
                    # This snapshot is read without the per-key lock; it is
                    # re-checked under state["lock"] before a key is taken.
                    tier1_keys, tier2_keys = [], []
                    for key, usage_count in keys_in_priority:
                        key_state = self.key_states[key]

                        # Tier 1: Completely idle keys (preferred)
                        if not key_state["models_in_use"]:
                            tier1_keys.append((key, usage_count))
                        # Tier 2: Keys that can accept more concurrent requests
                        elif (
                            key_state["models_in_use"].get(model, 0)
                            < effective_max_concurrent
                        ):
                            tier2_keys.append((key, usage_count))

                    if rotation_mode == "sequential":
                        # Sequential mode: sort credentials by priority, usage, recency
                        # Keep all candidates in sorted order (no filtering to single key)
                        selection_method = "sequential"
                        if tier1_keys:
                            tier1_keys = self._sort_sequential(
                                tier1_keys, credential_priorities
                            )
                        if tier2_keys:
                            tier2_keys = self._sort_sequential(
                                tier2_keys, credential_priorities
                            )
                    elif self.rotation_tolerance > 0:
                        # Balanced mode with weighted randomness
                        # (each tier is narrowed to its single randomly chosen key)
                        selection_method = "weighted-random"
                        if tier1_keys:
                            selected_key = self._select_weighted_random(
                                tier1_keys, self.rotation_tolerance
                            )
                            tier1_keys = [
                                (k, u) for k, u in tier1_keys if k == selected_key
                            ]
                        if tier2_keys:
                            selected_key = self._select_weighted_random(
                                tier2_keys, self.rotation_tolerance
                            )
                            tier2_keys = [
                                (k, u) for k, u in tier2_keys if k == selected_key
                            ]
                    else:
                        # Deterministic: sort by usage within each tier
                        selection_method = "least-used"
                        tier1_keys.sort(key=lambda x: x[1])
                        tier2_keys.sort(key=lambda x: x[1])

                    # Try to acquire from Tier 1 first
                    for key, usage in tier1_keys:
                        state = self.key_states[key]
                        async with state["lock"]:
                            # Re-check under the lock: another task may have taken it
                            if not state["models_in_use"]:
                                state["models_in_use"][model] = 1
                                tier_name = (
                                    credential_tier_names.get(key, "unknown")
                                    if credential_tier_names
                                    else "unknown"
                                )
                                quota_display = self._get_quota_display(key, model)
                                lib_logger.info(
                                    f"Acquired key {mask_credential(key)} for model {model} "
                                    f"(tier: {tier_name}, priority: {priority_level}, selection: {selection_method}, {quota_display})"
                                )
                                return key

                    # Then try Tier 2
                    for key, usage in tier2_keys:
                        state = self.key_states[key]
                        async with state["lock"]:
                            current_count = state["models_in_use"].get(model, 0)
                            if current_count < effective_max_concurrent:
                                state["models_in_use"][model] = current_count + 1
                                tier_name = (
                                    credential_tier_names.get(key, "unknown")
                                    if credential_tier_names
                                    else "unknown"
                                )
                                quota_display = self._get_quota_display(key, model)
                                lib_logger.info(
                                    f"Acquired key {mask_credential(key)} for model {model} "
                                    f"(tier: {tier_name}, priority: {priority_level}, selection: {selection_method}, concurrent: {state['models_in_use'][model]}/{effective_max_concurrent}, {quota_display})"
                                )
                                return key

                # If we get here, all priority groups were exhausted but keys might become available
                # Collect all keys across all priorities for waiting
                all_potential_keys = []
                for keys_list in priority_groups.values():
                    all_potential_keys.extend(keys_list)

                if not all_potential_keys:
                    # All credentials are on cooldown - check if waiting makes sense
                    # NOTE(review): passes the raw `model`, while the cooldown filter
                    # above uses `normalized_model`; confirm get_soonest_cooldown_end
                    # normalizes internally.
                    soonest_end = await self.get_soonest_cooldown_end(
                        available_keys, model
                    )

                    if soonest_end is None:
                        # No cooldowns active but no keys available (shouldn't happen)
                        lib_logger.warning(
                            "No keys eligible and no cooldowns active. Re-evaluating..."
                        )
                        await asyncio.sleep(1)
                        continue

                    remaining_budget = deadline - time.time()
                    wait_needed = soonest_end - time.time()

                    if wait_needed > remaining_budget:
                        # Fail fast - no credential will be available in time
                        lib_logger.warning(
                            f"All credentials on cooldown. Soonest available in {wait_needed:.1f}s, "
                            f"but only {remaining_budget:.1f}s budget remaining. Failing fast."
                        )
                        break  # Exit loop, will raise NoAvailableKeysError

                    # Wait for the credential to become available
                    lib_logger.info(
                        f"All credentials on cooldown. Waiting {wait_needed:.1f}s for soonest credential..."
                    )
                    await asyncio.sleep(min(wait_needed + 0.1, remaining_budget))
                    continue

                # Wait for the highest priority key with lowest usage
                best_priority = min(priority_groups.keys())
                best_priority_keys = priority_groups[best_priority]
                best_wait_key = min(best_priority_keys, key=lambda x: x[1])[0]
                wait_condition = self.key_states[best_wait_key]["condition"]

                lib_logger.info(
                    f"All Priority-{best_priority} keys are busy. Waiting for highest priority credential to become available..."
                )

            else:
                # Original logic when no priorities specified

                # Determine selection method based on provider's rotation mode
                provider = model.split("/")[0] if "/" in model else ""
                rotation_mode = self._get_rotation_mode(provider)

                # Calculate effective concurrency for default priority (999)
                # When no priorities are specified, all credentials get default priority
                default_priority = 999
                multiplier = self._get_priority_multiplier(
                    provider, default_priority, rotation_mode
                )
                effective_max_concurrent = max_concurrent * multiplier

                tier1_keys, tier2_keys = [], []

                # First, filter the list of available keys to exclude any on cooldown.
                async with self._data_lock:
                    for key in available_keys:
                        key_data = self._usage_data.get(key, {})
                        # Skip keys on cooldown (use normalized model for lookup)
                        if (key_data.get("key_cooldown_until") or 0) > now or (
                            key_data.get("model_cooldowns", {}).get(normalized_model)
                            or 0
                        ) > now:
                            continue

                        # Prioritize keys based on their current usage to ensure load balancing.
                        # Uses grouped usage if model is in a quota group
                        usage_count = self._get_grouped_usage_count(key, model)
                        key_state = self.key_states[key]

                        # Tier 1: Completely idle keys (preferred).
                        if not key_state["models_in_use"]:
                            tier1_keys.append((key, usage_count))
                        # Tier 2: Keys that can accept more concurrent requests for this model.
                        elif (
                            key_state["models_in_use"].get(model, 0)
                            < effective_max_concurrent
                        ):
                            tier2_keys.append((key, usage_count))

                # Fair cycle filtering (non-priority case)
                if provider and self._is_fair_cycle_enabled(provider, rotation_mode):
                    tier_key = self._get_tier_key(provider, default_priority)
                    tracking_key = self._get_tracking_key(
                        available_keys[0] if available_keys else "",
                        model,
                        provider,
                    )

                    # Get all credentials for this tier (for cycle completion check)
                    all_tier_creds = self._get_all_credentials_for_tier_key(
                        provider,
                        tier_key,
                        all_provider_credentials or available_keys,
                        None,
                    )

                    # Check if cycle should reset (all exhausted, expired, or none available)
                    if self._should_reset_cycle(
                        provider,
                        tier_key,
                        tracking_key,
                        all_tier_creds,
                        available_not_on_cooldown=[
                            key for key, _ in (tier1_keys + tier2_keys)
                        ],
                    ):
                        self._reset_cycle(provider, tier_key, tracking_key)

                    # Filter out exhausted credentials from both tiers
                    tier1_keys = [
                        (key, usage)
                        for key, usage in tier1_keys
                        if not self._is_credential_exhausted_in_cycle(
                            key, provider, tier_key, tracking_key
                        )
                    ]
                    tier2_keys = [
                        (key, usage)
                        for key, usage in tier2_keys
                        if not self._is_credential_exhausted_in_cycle(
                            key, provider, tier_key, tracking_key
                        )
                    ]

                if rotation_mode == "sequential":
                    # Sequential mode: sort credentials by priority, usage, recency
                    # Keep all candidates in sorted order (no filtering to single key)
                    selection_method = "sequential"
                    if tier1_keys:
                        tier1_keys = self._sort_sequential(
                            tier1_keys, credential_priorities
                        )
                    if tier2_keys:
                        tier2_keys = self._sort_sequential(
                            tier2_keys, credential_priorities
                        )
                elif self.rotation_tolerance > 0:
                    # Balanced mode with weighted randomness
                    selection_method = "weighted-random"
                    if tier1_keys:
                        selected_key = self._select_weighted_random(
                            tier1_keys, self.rotation_tolerance
                        )
                        tier1_keys = [
                            (k, u) for k, u in tier1_keys if k == selected_key
                        ]
                    if tier2_keys:
                        selected_key = self._select_weighted_random(
                            tier2_keys, self.rotation_tolerance
                        )
                        tier2_keys = [
                            (k, u) for k, u in tier2_keys if k == selected_key
                        ]
                else:
                    # Deterministic: sort by usage within each tier
                    selection_method = "least-used"
                    tier1_keys.sort(key=lambda x: x[1])
                    tier2_keys.sort(key=lambda x: x[1])

                # Attempt to acquire a key from Tier 1 first.
                for key, usage in tier1_keys:
                    state = self.key_states[key]
                    async with state["lock"]:
                        if not state["models_in_use"]:
                            state["models_in_use"][model] = 1
                            tier_name = (
                                credential_tier_names.get(key)
                                if credential_tier_names
                                else None
                            )
                            tier_info = f"tier: {tier_name}, " if tier_name else ""
                            quota_display = self._get_quota_display(key, model)
                            lib_logger.info(
                                f"Acquired key {mask_credential(key)} for model {model} "
                                f"({tier_info}selection: {selection_method}, {quota_display})"
                            )
                            return key

                # If no Tier 1 keys are available, try Tier 2.
                for key, usage in tier2_keys:
                    state = self.key_states[key]
                    async with state["lock"]:
                        current_count = state["models_in_use"].get(model, 0)
                        if current_count < effective_max_concurrent:
                            state["models_in_use"][model] = current_count + 1
                            tier_name = (
                                credential_tier_names.get(key)
                                if credential_tier_names
                                else None
                            )
                            tier_info = f"tier: {tier_name}, " if tier_name else ""
                            quota_display = self._get_quota_display(key, model)
                            lib_logger.info(
                                f"Acquired key {mask_credential(key)} for model {model} "
                                f"({tier_info}selection: {selection_method}, concurrent: {state['models_in_use'][model]}/{effective_max_concurrent}, {quota_display})"
                            )
                            return key

                # If all eligible keys are locked, wait for a key to be released.
                # NOTE(review): this message is logged before checking whether any
                # candidate exists, so it also appears when every key is on cooldown.
                lib_logger.info(
                    "All eligible keys are currently locked for this model. Waiting..."
                )

                all_potential_keys = tier1_keys + tier2_keys
                if not all_potential_keys:
                    # All credentials are on cooldown - check if waiting makes sense
                    soonest_end = await self.get_soonest_cooldown_end(
                        available_keys, model
                    )

                    if soonest_end is None:
                        # No cooldowns active but no keys available (shouldn't happen)
                        lib_logger.warning(
                            "No keys eligible and no cooldowns active. Re-evaluating..."
                        )
                        await asyncio.sleep(1)
                        continue

                    remaining_budget = deadline - time.time()
                    wait_needed = soonest_end - time.time()

                    if wait_needed > remaining_budget:
                        # Fail fast - no credential will be available in time
                        lib_logger.warning(
                            f"All credentials on cooldown. Soonest available in {wait_needed:.1f}s, "
                            f"but only {remaining_budget:.1f}s budget remaining. Failing fast."
                        )
                        break  # Exit loop, will raise NoAvailableKeysError

                    # Wait for the credential to become available
                    lib_logger.info(
                        f"All credentials on cooldown. Waiting {wait_needed:.1f}s for soonest credential..."
                    )
                    await asyncio.sleep(min(wait_needed + 0.1, remaining_budget))
                    continue

                # Wait on the condition of the key with the lowest current usage.
                best_wait_key = min(all_potential_keys, key=lambda x: x[1])[0]
                wait_condition = self.key_states[best_wait_key]["condition"]

            # Both branches above either returned, `continue`d, `break`-ed, or set
            # wait_condition; release_key notifies this condition when a slot frees.
            try:
                async with wait_condition:
                    remaining_budget = deadline - time.time()
                    if remaining_budget <= 0:
                        break  # Exit if the budget has already been exceeded.
                    # Wait for a notification, but no longer than the remaining budget or 1 second.
                    await asyncio.wait_for(
                        wait_condition.wait(), timeout=min(1, remaining_budget)
                    )
                lib_logger.info("Notified that a key was released. Re-evaluating...")
            except asyncio.TimeoutError:
                # This is not an error, just a timeout for the wait. The main loop will re-evaluate.
                lib_logger.info("Wait timed out. Re-evaluating for any available key.")

        # If the loop exits, it means the deadline was exceeded.
        raise NoAvailableKeysError(
            f"Could not acquire a key for model {model} within the global time budget."
) async def release_key(self, key: str, model: str): """Releases a key's lock for a specific model and notifies waiting tasks.""" if key not in self.key_states: return state = self.key_states[key] async with state["lock"]: if model in state["models_in_use"]: state["models_in_use"][model] -= 1 remaining = state["models_in_use"][model] if remaining <= 0: del state["models_in_use"][model] # Clean up when count reaches 0 lib_logger.info( f"Released credential {mask_credential(key)} from model {model} " f"(remaining concurrent: {max(0, remaining)})" ) else: lib_logger.warning( f"Attempted to release credential {mask_credential(key)} for model {model}, but it was not in use." ) # Notify all tasks waiting on this key's condition async with state["condition"]: state["condition"].notify_all() async def record_success( self, key: str, model: str, completion_response: Optional[litellm.ModelResponse] = None, ): """ Records a successful API call, resetting failure counters. It safely handles cases where token usage data is not available. 
Supports two modes based on provider configuration: - per_model: Each model has its own window_start_ts and stats in key_data["models"] - credential: Legacy mode with key_data["daily"]["models"] """ await self._lazy_init() # Normalize model name to public-facing name for consistent tracking model = self._normalize_model(key, model) async with self._data_lock: now_ts = time.time() today_utc_str = datetime.now(timezone.utc).date().isoformat() reset_config = self._get_usage_reset_config(key) reset_mode = ( reset_config.get("mode", "credential") if reset_config else "credential" ) if reset_mode == "per_model": # New per-model structure key_data = self._usage_data.setdefault( key, { "models": {}, "global": {"models": {}}, "model_cooldowns": {}, "failures": {}, }, ) # Ensure models dict exists if "models" not in key_data: key_data["models"] = {} # Get or create per-model data with window tracking model_data = key_data["models"].setdefault( model, { "window_start_ts": None, "quota_reset_ts": None, "success_count": 0, "failure_count": 0, "request_count": 0, "prompt_tokens": 0, "prompt_tokens_cached": 0, "completion_tokens": 0, "approx_cost": 0.0, }, ) # Start window on first request for this model if model_data.get("window_start_ts") is None: model_data["window_start_ts"] = now_ts # Set expected quota reset time from provider config window_seconds = ( reset_config.get("window_seconds", 0) if reset_config else 0 ) if window_seconds > 0: model_data["quota_reset_ts"] = now_ts + window_seconds window_hours = window_seconds / 3600 if window_seconds else 0 lib_logger.info( f"Started {window_hours:.1f}h window for model {model} on {mask_credential(key)}" ) # Record stats model_data["success_count"] += 1 model_data["request_count"] = model_data.get("request_count", 0) + 1 # Sync request_count across quota group (for providers with shared quota pools) new_request_count = model_data["request_count"] group = self._get_model_quota_group(key, model) if group: grouped_models = 
self._get_grouped_models(key, group) for grouped_model in grouped_models: if grouped_model != model: other_model_data = key_data["models"].setdefault( grouped_model, { "window_start_ts": None, "quota_reset_ts": None, "success_count": 0, "failure_count": 0, "request_count": 0, "prompt_tokens": 0, "prompt_tokens_cached": 0, "completion_tokens": 0, "approx_cost": 0.0, }, ) other_model_data["request_count"] = new_request_count # Sync window timing (shared quota pool = shared window) window_start = model_data.get("window_start_ts") if window_start: other_model_data["window_start_ts"] = window_start # Also sync quota_max_requests if set max_req = model_data.get("quota_max_requests") if max_req: other_model_data["quota_max_requests"] = max_req other_model_data["quota_display"] = ( f"{new_request_count}/{max_req}" ) # Update quota_display if max_requests is set (Antigravity-specific) max_req = model_data.get("quota_max_requests") if max_req: model_data["quota_display"] = ( f"{model_data['request_count']}/{max_req}" ) # Check custom cap if self._check_and_apply_custom_cap( key, model, model_data["request_count"] ): # Custom cap exceeded, cooldown applied # Continue to record tokens/cost but credential will be skipped next time pass usage_data_ref = model_data # For token/cost recording below else: # Legacy credential-level structure key_data = self._usage_data.setdefault( key, { "daily": {"date": today_utc_str, "models": {}}, "global": {"models": {}}, "model_cooldowns": {}, "failures": {}, }, ) if "last_daily_reset" not in key_data: key_data["last_daily_reset"] = today_utc_str # Get or create model data in daily structure usage_data_ref = key_data["daily"]["models"].setdefault( model, { "success_count": 0, "prompt_tokens": 0, "prompt_tokens_cached": 0, "completion_tokens": 0, "approx_cost": 0.0, }, ) usage_data_ref["success_count"] += 1 # Reset failures for this model model_failures = key_data.setdefault("failures", {}).setdefault(model, {}) 
model_failures["consecutive_failures"] = 0 # Clear transient cooldown on success (but NOT quota_reset_ts) if model in key_data.get("model_cooldowns", {}): del key_data["model_cooldowns"][model] # Record token and cost usage if ( completion_response and hasattr(completion_response, "usage") and completion_response.usage ): usage = completion_response.usage prompt_total = usage.prompt_tokens # Extract cached tokens from prompt_tokens_details if present cached_tokens = 0 prompt_details = getattr(usage, "prompt_tokens_details", None) if prompt_details: if isinstance(prompt_details, dict): cached_tokens = prompt_details.get("cached_tokens", 0) or 0 elif hasattr(prompt_details, "cached_tokens"): cached_tokens = prompt_details.cached_tokens or 0 # Store uncached tokens (prompt_tokens is total, subtract cached) uncached_tokens = prompt_total - cached_tokens usage_data_ref["prompt_tokens"] += uncached_tokens # Store cached tokens separately if cached_tokens > 0: usage_data_ref["prompt_tokens_cached"] = ( usage_data_ref.get("prompt_tokens_cached", 0) + cached_tokens ) usage_data_ref["completion_tokens"] += getattr( usage, "completion_tokens", 0 ) lib_logger.info( f"Recorded usage from response object for key {mask_credential(key)}" ) try: provider_name = model.split("/")[0] provider_instance = self._get_provider_instance(provider_name) if provider_instance and getattr( provider_instance, "skip_cost_calculation", False ): lib_logger.debug( f"Skipping cost calculation for provider '{provider_name}' (custom provider)." 
) else: if isinstance(completion_response, litellm.EmbeddingResponse): model_info = litellm.get_model_info(model) input_cost = model_info.get("input_cost_per_token") if input_cost: cost = ( completion_response.usage.prompt_tokens * input_cost ) else: cost = None else: cost = litellm.completion_cost( completion_response=completion_response, model=model ) if cost is not None: usage_data_ref["approx_cost"] += cost except Exception as e: lib_logger.warning( f"Could not calculate cost for model {model}: {e}" ) elif isinstance(completion_response, asyncio.Future) or hasattr( completion_response, "__aiter__" ): pass # Stream - usage recorded from chunks else: lib_logger.warning( f"No usage data found in completion response for model {model}. Recording success without token count." ) key_data["last_used_ts"] = now_ts await self._save_usage() async def record_failure( self, key: str, model: str, classified_error: ClassifiedError, increment_consecutive_failures: bool = True, ): """Records a failure and applies cooldowns based on error type. Distinguishes between: - quota_exceeded: Long cooldown with exact reset time (from quota_reset_timestamp) Sets quota_reset_ts on model (and group) - this becomes authoritative stats reset time - rate_limit: Short transient cooldown (just wait and retry) Only sets model_cooldowns - does NOT affect stats reset timing Args: key: The API key or credential identifier model: The model name classified_error: The classified error object increment_consecutive_failures: Whether to increment the failure counter. Set to False for provider-level errors that shouldn't count against the key. 
""" await self._lazy_init() # Normalize model name to public-facing name for consistent tracking model = self._normalize_model(key, model) async with self._data_lock: now_ts = time.time() today_utc_str = datetime.now(timezone.utc).date().isoformat() reset_config = self._get_usage_reset_config(key) reset_mode = ( reset_config.get("mode", "credential") if reset_config else "credential" ) # Initialize key data with appropriate structure if reset_mode == "per_model": key_data = self._usage_data.setdefault( key, { "models": {}, "global": {"models": {}}, "model_cooldowns": {}, "failures": {}, }, ) else: key_data = self._usage_data.setdefault( key, { "daily": {"date": today_utc_str, "models": {}}, "global": {"models": {}}, "model_cooldowns": {}, "failures": {}, }, ) # Provider-level errors (transient issues) should not count against the key provider_level_errors = {"server_error", "api_connection"} # Determine if we should increment the failure counter should_increment = ( increment_consecutive_failures and classified_error.error_type not in provider_level_errors ) # Calculate cooldown duration based on error type cooldown_seconds = None model_cooldowns = key_data.setdefault("model_cooldowns", {}) # Capture existing cooldown BEFORE we modify it # Used to determine if this is a fresh exhaustion vs re-processing existing_cooldown_before = model_cooldowns.get(model) was_already_on_cooldown = ( existing_cooldown_before is not None and existing_cooldown_before > now_ts ) if classified_error.error_type == "quota_exceeded": # Quota exhausted - use authoritative reset timestamp if available quota_reset_ts = classified_error.quota_reset_timestamp cooldown_seconds = ( classified_error.retry_after or COOLDOWN_RATE_LIMIT_DEFAULT ) if quota_reset_ts and reset_mode == "per_model": # Set quota_reset_ts on model - this becomes authoritative stats reset time models_data = key_data.setdefault("models", {}) model_data = models_data.setdefault( model, { "window_start_ts": None, 
"quota_reset_ts": None, "success_count": 0, "failure_count": 0, "request_count": 0, "prompt_tokens": 0, "prompt_tokens_cached": 0, "completion_tokens": 0, "approx_cost": 0.0, }, ) model_data["quota_reset_ts"] = quota_reset_ts # Track failure for quota estimation (request still consumes quota) model_data["failure_count"] = model_data.get("failure_count", 0) + 1 model_data["request_count"] = model_data.get("request_count", 0) + 1 # Clamp request_count to quota_max_requests when quota is exhausted # This prevents display overflow (e.g., 151/150) when requests are # counted locally before API refresh corrects the value max_req = model_data.get("quota_max_requests") if max_req is not None and model_data["request_count"] > max_req: model_data["request_count"] = max_req # Update quota_display with clamped value model_data["quota_display"] = f"{max_req}/{max_req}" new_request_count = model_data["request_count"] # Apply to all models in the same quota group group = self._get_model_quota_group(key, model) if group: grouped_models = self._get_grouped_models(key, group) for grouped_model in grouped_models: group_model_data = models_data.setdefault( grouped_model, { "window_start_ts": None, "quota_reset_ts": None, "success_count": 0, "failure_count": 0, "request_count": 0, "prompt_tokens": 0, "prompt_tokens_cached": 0, "completion_tokens": 0, "approx_cost": 0.0, }, ) group_model_data["quota_reset_ts"] = quota_reset_ts # Sync request_count across quota group group_model_data["request_count"] = new_request_count # Also sync quota_max_requests if set max_req = model_data.get("quota_max_requests") if max_req: group_model_data["quota_max_requests"] = max_req group_model_data["quota_display"] = ( f"{new_request_count}/{max_req}" ) # Also set transient cooldown for selection logic model_cooldowns[grouped_model] = quota_reset_ts reset_dt = datetime.fromtimestamp( quota_reset_ts, tz=timezone.utc ) lib_logger.info( f"Quota exhausted for group '{group}' ({len(grouped_models)} models) " 
f"on {mask_credential(key)}. Resets at {reset_dt.isoformat()}" ) else: reset_dt = datetime.fromtimestamp( quota_reset_ts, tz=timezone.utc ) hours = (quota_reset_ts - now_ts) / 3600 lib_logger.info( f"Quota exhausted for model {model} on {mask_credential(key)}. " f"Resets at {reset_dt.isoformat()} ({hours:.1f}h)" ) # Set transient cooldown for selection logic model_cooldowns[model] = quota_reset_ts else: # No authoritative timestamp or legacy mode - just use retry_after model_cooldowns[model] = now_ts + cooldown_seconds hours = cooldown_seconds / 3600 lib_logger.info( f"Quota exhausted on {mask_credential(key)} for model {model}. " f"Cooldown: {cooldown_seconds}s ({hours:.1f}h)" ) # Mark credential as exhausted for fair cycle if cooldown exceeds threshold # BUT only if this is a FRESH exhaustion (wasn't already on cooldown) # This prevents re-marking after cycle reset if not was_already_on_cooldown: effective_cooldown = ( (quota_reset_ts - now_ts) if quota_reset_ts else (cooldown_seconds or 0) ) provider = self._get_provider_from_credential(key) if provider: threshold = self._get_exhaustion_cooldown_threshold(provider) if effective_cooldown > threshold: rotation_mode = self._get_rotation_mode(provider) if self._is_fair_cycle_enabled(provider, rotation_mode): priority = self._get_credential_priority(key, provider) tier_key = self._get_tier_key(provider, priority) tracking_key = self._get_tracking_key( key, model, provider ) self._mark_credential_exhausted( key, provider, tier_key, tracking_key ) elif classified_error.error_type == "rate_limit": # Transient rate limit - just set short cooldown (does NOT set quota_reset_ts) cooldown_seconds = ( classified_error.retry_after or COOLDOWN_RATE_LIMIT_DEFAULT ) model_cooldowns[model] = now_ts + cooldown_seconds lib_logger.info( f"Rate limit on {mask_credential(key)} for model {model}. 
" f"Transient cooldown: {cooldown_seconds}s" ) elif classified_error.error_type == "authentication": # Apply a 5-minute key-level lockout for auth errors key_data["key_cooldown_until"] = now_ts + COOLDOWN_AUTH_ERROR cooldown_seconds = COOLDOWN_AUTH_ERROR model_cooldowns[model] = now_ts + cooldown_seconds lib_logger.warning( f"Authentication error on key {mask_credential(key)}. Applying 5-minute key-level lockout." ) # If we should increment failures, calculate escalating backoff if should_increment: failures_data = key_data.setdefault("failures", {}) model_failures = failures_data.setdefault( model, {"consecutive_failures": 0} ) model_failures["consecutive_failures"] += 1 count = model_failures["consecutive_failures"] # If cooldown wasn't set by specific error type, use escalating backoff if cooldown_seconds is None: cooldown_seconds = COOLDOWN_BACKOFF_TIERS.get( count, COOLDOWN_BACKOFF_MAX ) model_cooldowns[model] = now_ts + cooldown_seconds lib_logger.warning( f"Failure #{count} for key {mask_credential(key)} with model {model}. " f"Error type: {classified_error.error_type}, cooldown: {cooldown_seconds}s" ) else: # Provider-level errors: apply short cooldown but don't count against key if cooldown_seconds is None: cooldown_seconds = COOLDOWN_TRANSIENT_ERROR model_cooldowns[model] = now_ts + cooldown_seconds lib_logger.info( f"Provider-level error ({classified_error.error_type}) for key {mask_credential(key)} " f"with model {model}. NOT incrementing failures. 
Cooldown: {cooldown_seconds}s"
                )

            # Check for key-level lockout condition
            await self._check_key_lockout(key, key_data)

            # Track failure count for quota estimation (all failures consume quota)
            # This is separate from consecutive_failures which is for backoff logic
            if reset_mode == "per_model":
                models_data = key_data.setdefault("models", {})
                model_data = models_data.setdefault(
                    model,
                    {
                        "window_start_ts": None,
                        "quota_reset_ts": None,
                        "success_count": 0,
                        "failure_count": 0,
                        "request_count": 0,
                        "prompt_tokens": 0,
                        "prompt_tokens_cached": 0,
                        "completion_tokens": 0,
                        "approx_cost": 0.0,
                    },
                )
                # Only increment if not already incremented in quota_exceeded branch
                if classified_error.error_type != "quota_exceeded":
                    model_data["failure_count"] = model_data.get("failure_count", 0) + 1
                    model_data["request_count"] = model_data.get("request_count", 0) + 1

                # Sync request_count across quota group
                new_request_count = model_data["request_count"]
                group = self._get_model_quota_group(key, model)
                if group:
                    grouped_models = self._get_grouped_models(key, group)
                    for grouped_model in grouped_models:
                        if grouped_model != model:
                            other_model_data = models_data.setdefault(
                                grouped_model,
                                {
                                    "window_start_ts": None,
                                    "quota_reset_ts": None,
                                    "success_count": 0,
                                    "failure_count": 0,
                                    "request_count": 0,
                                    "prompt_tokens": 0,
                                    "prompt_tokens_cached": 0,
                                    "completion_tokens": 0,
                                    "approx_cost": 0.0,
                                },
                            )
                            other_model_data["request_count"] = new_request_count
                            # Also sync quota_max_requests if set
                            max_req = model_data.get("quota_max_requests")
                            if max_req:
                                other_model_data["quota_max_requests"] = max_req
                                other_model_data["quota_display"] = (
                                    f"{new_request_count}/{max_req}"
                                )

            key_data["last_failure"] = {
                "timestamp": now_ts,
                "model": model,
                "error": str(classified_error.original_exception),
            }

        await self._save_usage()

    async def update_quota_baseline(
        self,
        credential: str,
        model: str,
        remaining_fraction: float,
        max_requests: Optional[int] = None,
        reset_timestamp: Optional[float] = None,
    ) -> Optional[Dict[str, Any]]:
        """
        Update quota baseline data for a credential/model after fetching from API.

        This stores the current quota state as a baseline, which is used to
        estimate remaining quota based on subsequent request counts.

        Request-count sync: the local ``request_count`` is synced to the count
        implied by the API, ``int((1.0 - remaining_fraction) * max_requests)``
        (truncated). It only ever moves upward, because the larger of the local
        and API-derived counts is kept.

        Quota-group mirroring: every other model in the same quota group receives
        the same values for:
        - the synced count and the baseline fields,
        - ``quota_max_requests`` / ``quota_display``,
        - a valid reset timestamp,
        - the window start (if set),
        - the exhaustion cooldown.

        When quota is exhausted (remaining_fraction <= 0.0) and a valid
        reset_timestamp is provided, this also sets model_cooldowns to prevent
        wasted requests. The credential is additionally marked exhausted for fair
        cycling when all of the following hold:
        - the exhaustion is fresh (no cooldown for this model was still active
          before this call),
        - the remaining cooldown exceeds the provider's exhaustion-cooldown
          threshold,
        - fair cycling is enabled for the provider's rotation mode.

        The updated usage data is persisted with ``self._save_usage()`` before
        returning.

        Args:
            credential: Credential identifier (file path or env:// URI)
            model: Model name (with or without provider prefix)
            remaining_fraction: Current remaining quota as fraction (0.0 to 1.0)
            max_requests: Maximum requests allowed per quota period (e.g., 250 for
                Claude). If None, it is estimated with the provider plugin's
                ``get_max_requests_for_model(model_without_prefix, tier)``, where
                tier comes from the plugin's ``project_tier_cache`` (default
                "standard-tier"). If the plugin has no such method, the previously
                stored ``quota_max_requests`` is reused and the local request count
                is left unchanged.
            reset_timestamp: Unix timestamp when quota resets. Only trusted when
                remaining_fraction < 1.0 (quota has been used) and the timestamp is
                in the future. API returns garbage reset times for unused quota (100%).

        Returns:
            None if no cooldown was set/updated, otherwise:
            {
                "group_or_model": str,  # quota group name or model name if ungrouped
                "hours_until_reset": float,
            }
            A dict is returned only when the cooldown for ``model`` was actually
            written, i.e. it was absent or differed from reset_timestamp by more
            than 300 seconds.
        """
        await self._lazy_init()
        async with self._data_lock:
            now_ts = time.time()

            # Get or create key data structure
            key_data = self._usage_data.setdefault(
                credential,
                {
                    "models": {},
                    "global": {"models": {}},
                    "model_cooldowns": {},
                    "failures": {},
                },
            )

            # Ensure models dict exists
            if "models" not in key_data:
                key_data["models"] = {}

            # Get or create per-model data
            model_data = key_data["models"].setdefault(
                model,
                {
                    "window_start_ts": None,
                    "quota_reset_ts": None,
                    "success_count": 0,
                    "failure_count": 0,
                    "request_count": 0,
                    "prompt_tokens": 0,
                    "prompt_tokens_cached": 0,
                    "completion_tokens": 0,
                    "approx_cost": 0.0,
                    "baseline_remaining_fraction": None,
                    "baseline_fetched_at": None,
                    "requests_at_baseline": None,
                },
            )

            # Calculate actual used requests from API's remaining fraction
            # The API is authoritative - sync our local count to match reality
            # (int() truncates toward zero, so fractional requests round down.)
            if max_requests is not None:
                used_requests = int((1.0 - remaining_fraction) * max_requests)
            else:
                # Estimate max_requests from provider's quota cost
                # This matches how get_max_requests_for_model() calculates it
                provider = self._get_provider_from_credential(credential)
                plugin_instance = self._get_provider_instance(provider)
                if plugin_instance and hasattr(
                    plugin_instance, "get_max_requests_for_model"
                ):
                    # Get tier from provider's cache
                    tier = getattr(plugin_instance, "project_tier_cache", {}).get(
                        credential, "standard-tier"
                    )
                    # Strip provider prefix from model if present
                    clean_model = model.split("/")[-1] if "/" in model else model
                    max_requests = plugin_instance.get_max_requests_for_model(
                        clean_model, tier
                    )
                    # NOTE(review): if the plugin can return None here, the
                    # multiplication below raises TypeError — confirm that
                    # get_max_requests_for_model always returns a number.
                    used_requests = int((1.0 - remaining_fraction) * max_requests)
                else:
                    # Fallback: keep existing count if we can't calculate
                    # (used_requests == current count, so the max() below is a no-op)
                    used_requests = model_data.get("request_count", 0)
                    max_requests = model_data.get("quota_max_requests")

            # Sync local request count to API's authoritative value
            # Use max() to prevent API from resetting our count if it returns stale/cached 100%
            # The API can only increase our count (if we missed requests), not decrease it
            # See: https://github.com/Mirrowel/LLM-API-Key-Proxy/issues/75
            current_count = model_data.get("request_count", 0)
            synced_count = max(current_count, used_requests)
            model_data["request_count"] = synced_count
            model_data["requests_at_baseline"] = synced_count

            # Update baseline fields
            model_data["baseline_remaining_fraction"] = remaining_fraction
            model_data["baseline_fetched_at"] = now_ts

            # Update max_requests and quota_display
            if max_requests is not None:
                model_data["quota_max_requests"] = max_requests
                model_data["quota_display"] = f"{synced_count}/{max_requests}"

            # Handle reset_timestamp: only trust it when quota has been used (< 100%)
            # API returns garbage reset times for unused quota
            # A reset time that is already in the past is also ignored.
            valid_reset_ts = (
                reset_timestamp is not None
                and remaining_fraction < 1.0
                and reset_timestamp > now_ts
            )
            if valid_reset_ts:
                model_data["quota_reset_ts"] = reset_timestamp

            # Set cooldowns when quota is exhausted
            model_cooldowns = key_data.setdefault("model_cooldowns", {})
            is_exhausted = remaining_fraction <= 0.0
            cooldown_set_info = (
                None  # Will be returned if cooldown was newly set/updated
            )

            if is_exhausted and valid_reset_ts:
                # Check if there was an existing ACTIVE cooldown before we update
                # This distinguishes between fresh exhaustion vs refresh of existing state
                existing_cooldown = model_cooldowns.get(model)
                was_already_on_cooldown = (
                    existing_cooldown is not None and existing_cooldown > now_ts
                )

                # Only update cooldown if not set or differs by more than 5 minutes
                # (300 s tolerance avoids rewriting on small reset-time jitter)
                should_update = (
                    existing_cooldown is None
                    or abs(existing_cooldown - reset_timestamp) > 300
                )
                if should_update:
                    model_cooldowns[model] = reset_timestamp
                    hours_until_reset = (reset_timestamp - now_ts) / 3600
                    # Determine group or model name for logging
                    group = self._get_model_quota_group(credential, model)
                    cooldown_set_info = {
                        "group_or_model": group if group else model.split("/")[-1],
                        "hours_until_reset": hours_until_reset,
                    }

                # Mark credential as exhausted in fair cycle if cooldown exceeds threshold
                # BUT only if this is a FRESH exhaustion (wasn't already on cooldown)
                # This prevents re-marking after cycle reset when quota refresh sees existing cooldown
                if not was_already_on_cooldown:
                    cooldown_duration = reset_timestamp - now_ts
                    provider = self._get_provider_from_credential(credential)
                    if provider:
                        threshold = self._get_exhaustion_cooldown_threshold(provider)
                        if cooldown_duration > threshold:
                            rotation_mode = self._get_rotation_mode(provider)
                            if self._is_fair_cycle_enabled(provider, rotation_mode):
                                priority = self._get_credential_priority(
                                    credential, provider
                                )
                                tier_key = self._get_tier_key(provider, priority)
                                tracking_key = self._get_tracking_key(
                                    credential, model, provider
                                )
                                self._mark_credential_exhausted(
                                    credential, provider, tier_key, tracking_key
                                )

                # Defensive clamp: ensure request_count doesn't exceed max when exhausted
                # (requests_at_baseline and synced_count keep the unclamped value)
                if (
                    max_requests is not None
                    and model_data["request_count"] > max_requests
                ):
                    model_data["request_count"] = max_requests
                    model_data["quota_display"] = f"{max_requests}/{max_requests}"

            # Sync baseline fields and quota info across quota group
            group = self._get_model_quota_group(credential, model)
            if group:
                grouped_models = self._get_grouped_models(credential, group)
                for grouped_model in grouped_models:
                    if grouped_model != model:
                        other_model_data = key_data["models"].setdefault(
                            grouped_model,
                            {
                                "window_start_ts": None,
                                "quota_reset_ts": None,
                                "success_count": 0,
                                "failure_count": 0,
                                "request_count": 0,
                                "prompt_tokens": 0,
                                "prompt_tokens_cached": 0,
                                "completion_tokens": 0,
                                "approx_cost": 0.0,
                            },
                        )

                        # Sync request tracking (use synced_count to prevent reset bug)
                        other_model_data["request_count"] = synced_count
                        if max_requests is not None:
                            other_model_data["quota_max_requests"] = max_requests
                            other_model_data["quota_display"] = (
                                f"{synced_count}/{max_requests}"
                            )

                        # Sync baseline fields
                        other_model_data["baseline_remaining_fraction"] = (
                            remaining_fraction
                        )
                        other_model_data["baseline_fetched_at"] = now_ts
                        other_model_data["requests_at_baseline"] = synced_count

                        # Sync reset timestamp if valid
                        if valid_reset_ts:
                            other_model_data["quota_reset_ts"] = reset_timestamp

                        # Sync window start time
                        window_start = model_data.get("window_start_ts")
                        if window_start:
                            other_model_data["window_start_ts"] = window_start

                        # Sync cooldown if exhausted (with ±5 min check)
                        if is_exhausted and valid_reset_ts:
                            existing_grouped = model_cooldowns.get(grouped_model)
                            should_update_grouped = (
                                existing_grouped is None
                                or abs(existing_grouped - reset_timestamp) > 300
                            )
                            if should_update_grouped:
                                model_cooldowns[grouped_model] = reset_timestamp

                            # Defensive clamp for grouped models when exhausted
                            if (
                                max_requests is not None
                                and other_model_data["request_count"] > max_requests
                            ):
                                other_model_data["request_count"] = max_requests
                                other_model_data["quota_display"] = (
                                    f"{max_requests}/{max_requests}"
                                )

            lib_logger.debug(
                f"Updated quota baseline for {mask_credential(credential)} model={model}: "
                f"remaining={remaining_fraction:.2%}, synced_request_count={synced_count}"
            )

        await self._save_usage()
        return cooldown_set_info

    async def _check_key_lockout(self, key: str, key_data: Dict) -> None:
        """
        Checks if a key should be locked out due to multiple model failures.

        NOTE: This check is currently disabled. The original logic counted individual
        models in long-term lockout, but this caused issues with quota groups - when
        a single quota group (e.g., "claude" with 5 models) was exhausted, it would
        count as 5 lockouts and trigger key-level lockout, blocking other quota groups
        (like gemini) that were still available.

        The per-model and per-group cooldowns already handle quota exhaustion properly.

        As written, this coroutine is a no-op: it neither reads nor modifies
        ``key`` or ``key_data`` and returns None.
        """
        # Disabled - see docstring above
        pass

    async def get_stats_for_endpoint(
        self,
        provider_filter: Optional[str] = None,
        include_global: bool = True,
    ) -> Dict[str, Any]:
        """
        Get usage stats formatted for the /v1/quota-stats endpoint.

        Aggregates data from key_usage.json grouped by provider.
        Includes both current period stats and global (lifetime) stats.

        Credentials whose provider cannot be determined are skipped. Each
        credential gets exactly one status:
        - "exhausted": every dict-valued per-model entry has
          baseline_remaining_fraction <= 0. A missing baseline counts as "not
          exhausted".
        - "cooldown": otherwise, if the key cooldown or any model cooldown lies
          in the future.
        - "active": in all other cases.

        Global (lifetime) figures are the archived "global" per-model stats plus
        the current-period figures.

        Args:
            provider_filter: If provided, only return stats for this provider
            include_global: If True, include global/lifetime stats alongside current

        Returns:
            {
                "providers": {
                    "provider_name": {
                        "credential_count": int,
                        "active_count": int,
                        "on_cooldown_count": int,
                        "exhausted_count": int,
                        "total_requests": int,
                        "tokens": {
                            "input_cached": int,
                            "input_uncached": int,
                            "input_cache_pct": float,
                            "output": int
                        },
                        "approx_cost": float | None,
                        "credentials": [...],
                        "global": {...}  # If include_global is True
                    }
                },
                "summary": {...},
                "global_summary": {...},  # If include_global is True
                "data_source": "cache",
                "timestamp": float
            }
        """
        await self._lazy_init()

        now_ts = time.time()
        providers: Dict[str, Dict[str, Any]] = {}
        # Track global stats separately
        global_providers: Dict[str, Dict[str, Any]] = {}

        async with self._data_lock:
            if not self._usage_data:
                # NOTE(review): this early return always includes "global_summary"
                # (even when include_global is False) and reports approx_total_cost
                # as 0.0, whereas the normal path below uses None for zero cost.
                return {
                    "providers": {},
                    "summary": {
                        "total_providers": 0,
                        "total_credentials": 0,
                        "active_credentials": 0,
                        "exhausted_credentials": 0,
                        "total_requests": 0,
                        "tokens": {
                            "input_cached": 0,
                            "input_uncached": 0,
                            "input_cache_pct": 0,
                            "output": 0,
                        },
                        "approx_total_cost": 0.0,
                    },
                    "global_summary": {
                        "total_providers": 0,
                        "total_credentials": 0,
                        "total_requests": 0,
                        "tokens": {
                            "input_cached": 0,
                            "input_uncached": 0,
                            "input_cache_pct": 0,
                            "output": 0,
                        },
                        "approx_total_cost": 0.0,
                    },
                    "data_source": "cache",
                    "timestamp": now_ts,
                }

            for credential, cred_data in self._usage_data.items():
                # Extract provider from credential path
                provider = self._get_provider_from_credential(credential)
                if not provider:
                    continue

                # Apply filter if specified
                if provider_filter and provider != provider_filter:
                    continue

                # Initialize provider entry
                if provider not in providers:
                    providers[provider] = {
                        "credential_count": 0,
                        "active_count": 0,
                        "on_cooldown_count": 0,
                        "exhausted_count": 0,
                        "total_requests": 0,
                        "tokens": {
                            "input_cached": 0,
                            "input_uncached": 0,
                            "input_cache_pct": 0,
                            "output": 0,
                        },
                        "approx_cost": 0.0,
                        "credentials": [],
                    }
                    global_providers[provider] = {
                        "total_requests": 0,
                        "tokens": {
                            "input_cached": 0,
                            "input_uncached": 0,
                            "input_cache_pct": 0,
                            "output": 0,
                        },
                        "approx_cost": 0.0,
                    }

                prov_stats = providers[provider]
                prov_stats["credential_count"] += 1

                # Determine credential status and cooldowns
                key_cooldown = cred_data.get("key_cooldown_until", 0) or 0
                model_cooldowns = cred_data.get("model_cooldowns", {})

                # Build active cooldowns with remaining time
                active_cooldowns = {}
                for model, cooldown_ts in model_cooldowns.items():
                    # NOTE(review): assumes every cooldown value is numeric; a None
                    # entry would raise TypeError here (key_cooldown_until is guarded
                    # with "or 0" above, model cooldown values are not).
                    if cooldown_ts > now_ts:
                        remaining_seconds = int(cooldown_ts - now_ts)
                        active_cooldowns[model] = {
                            "until_ts": cooldown_ts,
                            "remaining_seconds": remaining_seconds,
                        }

                key_cooldown_remaining = None
                if key_cooldown > now_ts:
                    key_cooldown_remaining = int(key_cooldown - now_ts)

                has_active_cooldown = key_cooldown > now_ts or len(active_cooldowns) > 0

                # Check if exhausted (all quota groups exhausted for Antigravity)
                is_exhausted = False
                models_data = cred_data.get("models", {})
                if models_data:
                    # Check if any model has remaining quota
                    # (a model with no baseline yet counts as having quota; if no
                    # entry is a dict at all, all_exhausted stays True)
                    all_exhausted = True
                    for model_stats in models_data.values():
                        if isinstance(model_stats, dict):
                            baseline = model_stats.get("baseline_remaining_fraction")
                            if baseline is None or baseline > 0:
                                all_exhausted = False
                                break
                    if all_exhausted and len(models_data) > 0:
                        is_exhausted = True

                # Exhaustion takes precedence over cooldown for the reported status
                if is_exhausted:
                    prov_stats["exhausted_count"] += 1
                    status = "exhausted"
                elif has_active_cooldown:
                    prov_stats["on_cooldown_count"] += 1
                    status = "cooldown"
                else:
                    prov_stats["active_count"] += 1
                    status = "active"

                # Aggregate token stats (current period)
                cred_tokens = {
                    "input_cached": 0,
                    "input_uncached": 0,
                    "output": 0,
                }
                cred_requests = 0
                cred_cost = 0.0

                # Aggregate global token stats
                cred_global_tokens = {
                    "input_cached": 0,
                    "input_uncached": 0,
                    "output": 0,
                }
                cred_global_requests = 0
                cred_global_cost = 0.0

                # Handle per-model structure (current period)
                if models_data:
                    for model_name, model_stats in models_data.items():
                        if not isinstance(model_stats, dict):
                            continue
                        # Prefer request_count if available and non-zero, else fall back to success+failure
                        req_count = model_stats.get("request_count", 0)
                        if req_count > 0:
                            cred_requests += req_count
                        else:
                            cred_requests += model_stats.get("success_count", 0)
                            cred_requests += model_stats.get("failure_count", 0)
                        # Token stats - track cached separately
                        cred_tokens["input_cached"] += model_stats.get(
                            "prompt_tokens_cached", 0
                        )
                        cred_tokens["input_uncached"] += model_stats.get(
                            "prompt_tokens", 0
                        )
                        cred_tokens["output"] += model_stats.get("completion_tokens", 0)
                        cred_cost += model_stats.get("approx_cost", 0.0)

                # Handle legacy daily structure
                # NOTE(review): only success_count is counted here and in the global
                # loop below, whereas the per-model loop above prefers request_count
                # (which includes failures) — confirm this asymmetry is intended.
                daily_data = cred_data.get("daily", {})
                daily_models = daily_data.get("models", {})
                for model_name, model_stats in daily_models.items():
                    if not isinstance(model_stats, dict):
                        continue
                    cred_requests += model_stats.get("success_count", 0)
                    cred_tokens["input_cached"] += model_stats.get(
                        "prompt_tokens_cached", 0
                    )
                    cred_tokens["input_uncached"] += model_stats.get("prompt_tokens", 0)
                    cred_tokens["output"] += model_stats.get("completion_tokens", 0)
                    cred_cost += model_stats.get("approx_cost", 0.0)

                # Handle global stats
                global_data = cred_data.get("global", {})
                global_models = global_data.get("models", {})
                for model_name, model_stats in global_models.items():
                    if not isinstance(model_stats, dict):
                        continue
                    cred_global_requests += model_stats.get("success_count", 0)
                    cred_global_tokens["input_cached"] += model_stats.get(
                        "prompt_tokens_cached", 0
                    )
                    cred_global_tokens["input_uncached"] += model_stats.get(
                        "prompt_tokens", 0
                    )
                    cred_global_tokens["output"] += model_stats.get(
                        "completion_tokens", 0
                    )
                    cred_global_cost += model_stats.get("approx_cost", 0.0)

                # Add current period stats to global totals
                cred_global_requests += cred_requests
                cred_global_tokens["input_cached"] += cred_tokens["input_cached"]
                cred_global_tokens["input_uncached"] += cred_tokens["input_uncached"]
                cred_global_tokens["output"] += cred_tokens["output"]
                cred_global_cost += cred_cost

                # Build credential entry
                # Mask credential identifier for display
                # (env:// URIs are shown as-is; file paths are reduced to their file
                # name, although the full path is still included under "full_path")
                if credential.startswith("env://"):
                    identifier = credential
                else:
                    identifier = Path(credential).name

                cred_entry = {
                    "identifier": identifier,
                    "full_path": credential,
                    "status": status,
                    "last_used_ts": cred_data.get("last_used_ts"),
                    "requests": cred_requests,
                    "tokens": cred_tokens,
                    "approx_cost": cred_cost if cred_cost > 0 else None,
                }

                # Add cooldown info
                if key_cooldown_remaining is not None:
                    cred_entry["key_cooldown_remaining"] = key_cooldown_remaining
                if active_cooldowns:
                    cred_entry["model_cooldowns"] = active_cooldowns

                # Add global stats for this credential
                if include_global:
                    # Calculate global cache percentage
                    global_total_input = (
                        cred_global_tokens["input_cached"]
                        + cred_global_tokens["input_uncached"]
                    )
                    global_cache_pct = (
                        round(
                            cred_global_tokens["input_cached"] / global_total_input * 100,
                            1,
                        )
                        if global_total_input > 0
                        else 0
                    )

                    cred_entry["global"] = {
                        "requests": cred_global_requests,
                        "tokens": {
                            "input_cached": cred_global_tokens["input_cached"],
                            "input_uncached": cred_global_tokens["input_uncached"],
                            "input_cache_pct": global_cache_pct,
                            "output": cred_global_tokens["output"],
                        },
                        "approx_cost": cred_global_cost if cred_global_cost > 0 else None,
                    }

                # Add model-specific data for providers with per-model tracking
                if models_data:
                    cred_entry["models"] = {}
                    for model_name, model_stats in models_data.items():
                        if not isinstance(model_stats, dict):
                            continue
                        cred_entry["models"][model_name] = {
                            "requests": model_stats.get("success_count", 0)
                            + model_stats.get("failure_count", 0),
                            "request_count": model_stats.get("request_count", 0),
                            "success_count": model_stats.get("success_count", 0),
                            "failure_count": model_stats.get("failure_count", 0),
                            "prompt_tokens": model_stats.get("prompt_tokens", 0),
                            "prompt_tokens_cached": model_stats.get(
                                "prompt_tokens_cached", 0
                            ),
                            "completion_tokens": model_stats.get(
                                "completion_tokens", 0
                            ),
                            "approx_cost": model_stats.get("approx_cost", 0.0),
                            "window_start_ts": model_stats.get("window_start_ts"),
                            "quota_reset_ts": model_stats.get("quota_reset_ts"),
                            # Quota baseline fields (Antigravity-specific)
                            "baseline_remaining_fraction": model_stats.get(
                                "baseline_remaining_fraction"
                            ),
                            "baseline_fetched_at": model_stats.get(
                                "baseline_fetched_at"
                            ),
                            "quota_max_requests": model_stats.get("quota_max_requests"),
                            "quota_display": model_stats.get("quota_display"),
                        }

                prov_stats["credentials"].append(cred_entry)

                # Aggregate to provider totals (current period)
                prov_stats["total_requests"] += cred_requests
                prov_stats["tokens"]["input_cached"] += cred_tokens["input_cached"]
                prov_stats["tokens"]["input_uncached"] += cred_tokens["input_uncached"]
                prov_stats["tokens"]["output"] += cred_tokens["output"]
                if cred_cost > 0:
                    prov_stats["approx_cost"] += cred_cost

                # Aggregate to global provider totals
                global_providers[provider]["total_requests"] += cred_global_requests
                global_providers[provider]["tokens"]["input_cached"] += (
                    cred_global_tokens["input_cached"]
                )
                global_providers[provider]["tokens"]["input_uncached"] += (
                    cred_global_tokens["input_uncached"]
                )
                global_providers[provider]["tokens"]["output"] += cred_global_tokens[
                    "output"
                ]
                global_providers[provider]["approx_cost"] += cred_global_cost

        # The post-processing below works only on the local aggregates
        # (providers / global_providers) built above.

        # Calculate cache percentages for each provider
        for provider, prov_stats in providers.items():
            total_input = (
                prov_stats["tokens"]["input_cached"]
                + prov_stats["tokens"]["input_uncached"]
            )
            if total_input > 0:
                prov_stats["tokens"]["input_cache_pct"] = round(
                    prov_stats["tokens"]["input_cached"] / total_input * 100, 1
                )
            # Set cost to None if 0
            if prov_stats["approx_cost"] == 0:
                prov_stats["approx_cost"] = None

            # Calculate global cache percentages
            if include_global and provider in global_providers:
                gp = global_providers[provider]
                global_total = (
                    gp["tokens"]["input_cached"] + gp["tokens"]["input_uncached"]
                )
                if global_total > 0:
                    gp["tokens"]["input_cache_pct"] = round(
                        gp["tokens"]["input_cached"] / global_total * 100, 1
                    )
                if gp["approx_cost"] == 0:
                    gp["approx_cost"] = None
                prov_stats["global"] = gp

        # Build summary (current period)
        total_creds = sum(p["credential_count"] for p in providers.values())
        active_creds = sum(p["active_count"] for p in providers.values())
        exhausted_creds = sum(p["exhausted_count"] for p in providers.values())
        total_requests = sum(p["total_requests"] for p in providers.values())
        total_input_cached = sum(
            p["tokens"]["input_cached"] for p in providers.values()
        )
        total_input_uncached = sum(
            p["tokens"]["input_uncached"] for p in providers.values()
        )
        total_output = sum(p["tokens"]["output"] for p in providers.values())
        # approx_cost may have been replaced by None above, hence "or 0"
        total_cost = sum(p["approx_cost"] or 0 for p in providers.values())

        total_input = total_input_cached + total_input_uncached
        input_cache_pct = (
            round(total_input_cached / total_input * 100, 1) if total_input > 0 else 0
        )

        result = {
            "providers": providers,
            "summary": {
                "total_providers": len(providers),
                "total_credentials": total_creds,
                "active_credentials": active_creds,
                "exhausted_credentials": exhausted_creds,
                "total_requests": total_requests,
                "tokens": {
                    "input_cached": total_input_cached,
                    "input_uncached": total_input_uncached,
                    "input_cache_pct": input_cache_pct,
                    "output": total_output,
                },
                "approx_total_cost": total_cost if total_cost > 0 else None,
            },
            "data_source": "cache",
            "timestamp": now_ts,
        }

        # Build global summary
        if include_global:
            global_total_requests = sum(
                gp["total_requests"] for gp in global_providers.values()
            )
            global_total_input_cached = sum(
                gp["tokens"]["input_cached"] for gp in global_providers.values()
            )
            global_total_input_uncached = sum(
                gp["tokens"]["input_uncached"] for gp in global_providers.values()
            )
            global_total_output = sum(
                gp["tokens"]["output"] for gp in global_providers.values()
            )
            global_total_cost = sum(
                gp["approx_cost"] or 0 for gp in global_providers.values()
            )

            global_total_input = global_total_input_cached + global_total_input_uncached
            global_input_cache_pct = (
                round(global_total_input_cached / global_total_input * 100, 1)
                if global_total_input > 0
                else 0
            )

            result["global_summary"] = {
                "total_providers": len(global_providers),
                "total_credentials": total_creds,
                "total_requests": global_total_requests,
                "tokens": {
                    "input_cached": global_total_input_cached,
                    "input_uncached": global_total_input_uncached,
                    "input_cache_pct": global_input_cache_pct,
                    "output": global_total_output,
                },
                "approx_total_cost": global_total_cost if global_total_cost > 0 else None,
            }

        return result

    async def reload_from_disk(self) -> None:
        """
        Force reload usage data from disk.

        Useful when another process may have updated the file.

        Under ``self._init_lock`` this clears ``self._initialized``, reloads via
        ``self._load_usage()``, applies ``self._reset_daily_stats_if_needed()``,
        and then sets ``self._initialized`` again.
        """
        # NOTE(review): presumably _initialized is an asyncio.Event that
        # _lazy_init() waits on, so concurrent callers pause until the reload
        # finishes — verify against _lazy_init().
        async with self._init_lock:
            self._initialized.clear()
            await self._load_usage()
            await self._reset_daily_stats_if_needed()
            self._initialized.set()