# llm-api-proxy / src/rotator_library/usage_manager.py
# Author: Mirrowel
# refactor(quota-viewer): enhance credential sorting and cooldown display
# Commit: 1af1879
import json
import os
import time
import logging
import asyncio
import random
from datetime import date, datetime, timezone, time as dt_time
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import aiofiles
import litellm
from .error_handler import ClassifiedError, NoAvailableKeysError, mask_credential
from .providers import PROVIDER_PLUGINS
from .utils.resilient_io import ResilientStateWriter
from .utils.paths import get_data_file
# Library-scoped logger. Propagation is disabled so records do not leak into
# the host application's root logger, and a NullHandler is attached (only
# when no handler exists yet) to silence "no handlers" warnings when the
# embedding application has not configured logging for this library.
lib_logger = logging.getLogger("rotator_library")
lib_logger.propagate = False
if not lib_logger.handlers:
    lib_logger.addHandler(logging.NullHandler())
class UsageManager:
"""
Manages usage statistics and cooldowns for API keys with asyncio-safe locking,
asynchronous file I/O, lazy-loading mechanism, and weighted random credential rotation.
The credential rotation strategy can be configured via the `rotation_tolerance` parameter:
- **tolerance = 0.0**: Deterministic least-used selection. The credential with
the lowest usage count is always selected. This provides predictable, perfectly balanced
load distribution but may be vulnerable to fingerprinting.
- **tolerance = 2.0 - 4.0 (default, recommended)**: Balanced weighted randomness. Credentials are selected
randomly with weights biased toward less-used ones. Credentials within 2 uses of the
maximum can still be selected with reasonable probability. This provides security through
unpredictability while maintaining good load balance.
- **tolerance = 5.0+**: High randomness. Even heavily-used credentials have significant
selection probability. Useful for stress testing or maximum unpredictability, but may
result in less balanced load distribution.
The weight formula is: `weight = (max_usage - credential_usage) + tolerance + 1`
This ensures lower-usage credentials are preferred while tolerance controls how much
randomness is introduced into the selection process.
Additionally, providers can specify a rotation mode:
- "balanced" (default): Rotate credentials to distribute load evenly
- "sequential": Use one credential until exhausted (preserves caching)
"""
def __init__(
self,
file_path: Optional[Union[str, Path]] = None,
daily_reset_time_utc: Optional[str] = "03:00",
rotation_tolerance: float = 0.0,
provider_rotation_modes: Optional[Dict[str, str]] = None,
provider_plugins: Optional[Dict[str, Any]] = None,
priority_multipliers: Optional[Dict[str, Dict[int, int]]] = None,
priority_multipliers_by_mode: Optional[
Dict[str, Dict[str, Dict[int, int]]]
] = None,
sequential_fallback_multipliers: Optional[Dict[str, int]] = None,
):
"""
Initialize the UsageManager.
Args:
file_path: Path to the usage data JSON file. If None, uses get_data_file("key_usage.json").
Can be absolute Path, relative Path, or string.
daily_reset_time_utc: Time in UTC when daily stats should reset (HH:MM format)
rotation_tolerance: Tolerance for weighted random credential rotation.
- 0.0: Deterministic, least-used credential always selected
- tolerance = 2.0 - 4.0 (default, recommended): Balanced randomness, can pick credentials within 2 uses of max
- 5.0+: High randomness, more unpredictable selection patterns
provider_rotation_modes: Dict mapping provider names to rotation modes.
- "balanced": Rotate credentials to distribute load evenly (default)
- "sequential": Use one credential until exhausted (preserves caching)
provider_plugins: Dict mapping provider names to provider plugin instances.
Used for per-provider usage reset configuration (window durations, field names).
priority_multipliers: Dict mapping provider -> priority -> multiplier.
Universal multipliers that apply regardless of rotation mode.
Example: {"antigravity": {1: 5, 2: 3}}
priority_multipliers_by_mode: Dict mapping provider -> mode -> priority -> multiplier.
Mode-specific overrides. Example: {"antigravity": {"balanced": {3: 1}}}
sequential_fallback_multipliers: Dict mapping provider -> fallback multiplier.
Used in sequential mode when priority not in priority_multipliers.
Example: {"antigravity": 2}
"""
# Resolve file_path - use default if not provided
if file_path is None:
self.file_path = str(get_data_file("key_usage.json"))
elif isinstance(file_path, Path):
self.file_path = str(file_path)
else:
# String path - could be relative or absolute
self.file_path = file_path
self.rotation_tolerance = rotation_tolerance
self.provider_rotation_modes = provider_rotation_modes or {}
self.provider_plugins = provider_plugins or PROVIDER_PLUGINS
self.priority_multipliers = priority_multipliers or {}
self.priority_multipliers_by_mode = priority_multipliers_by_mode or {}
self.sequential_fallback_multipliers = sequential_fallback_multipliers or {}
self._provider_instances: Dict[str, Any] = {} # Cache for provider instances
self.key_states: Dict[str, Dict[str, Any]] = {}
self._data_lock = asyncio.Lock()
self._usage_data: Optional[Dict] = None
self._initialized = asyncio.Event()
self._init_lock = asyncio.Lock()
self._timeout_lock = asyncio.Lock()
self._claimed_on_timeout: Set[str] = set()
# Resilient writer for usage data persistence
self._state_writer = ResilientStateWriter(file_path, lib_logger)
if daily_reset_time_utc:
hour, minute = map(int, daily_reset_time_utc.split(":"))
self.daily_reset_time_utc = dt_time(
hour=hour, minute=minute, tzinfo=timezone.utc
)
else:
self.daily_reset_time_utc = None
def _get_rotation_mode(self, provider: str) -> str:
"""
Get the rotation mode for a provider.
Args:
provider: Provider name (e.g., "antigravity", "gemini_cli")
Returns:
"balanced" or "sequential"
"""
return self.provider_rotation_modes.get(provider, "balanced")
def _get_priority_multiplier(
self, provider: str, priority: int, rotation_mode: str
) -> int:
"""
Get the concurrency multiplier for a provider/priority/mode combination.
Lookup order:
1. Mode-specific tier override: priority_multipliers_by_mode[provider][mode][priority]
2. Universal tier multiplier: priority_multipliers[provider][priority]
3. Sequential fallback (if mode is sequential): sequential_fallback_multipliers[provider]
4. Global default: 1 (no multiplier effect)
Args:
provider: Provider name (e.g., "antigravity")
priority: Priority level (1 = highest priority)
rotation_mode: Current rotation mode ("sequential" or "balanced")
Returns:
Multiplier value
"""
provider_lower = provider.lower()
# 1. Check mode-specific override
if provider_lower in self.priority_multipliers_by_mode:
mode_multipliers = self.priority_multipliers_by_mode[provider_lower]
if rotation_mode in mode_multipliers:
if priority in mode_multipliers[rotation_mode]:
return mode_multipliers[rotation_mode][priority]
# 2. Check universal tier multiplier
if provider_lower in self.priority_multipliers:
if priority in self.priority_multipliers[provider_lower]:
return self.priority_multipliers[provider_lower][priority]
# 3. Sequential fallback (only for sequential mode)
if rotation_mode == "sequential":
if provider_lower in self.sequential_fallback_multipliers:
return self.sequential_fallback_multipliers[provider_lower]
# 4. Global default
return 1
def _get_provider_from_credential(self, credential: str) -> Optional[str]:
"""
Extract provider name from credential path or identifier.
Supports multiple credential formats:
- OAuth: "oauth_creds/antigravity_oauth_15.json" -> "antigravity"
- OAuth: "C:\\...\\oauth_creds\\gemini_cli_oauth_1.json" -> "gemini_cli"
- OAuth filename only: "antigravity_oauth_1.json" -> "antigravity"
- API key style: stored with provider prefix metadata
Args:
credential: The credential identifier (path or key)
Returns:
Provider name string or None if cannot be determined
"""
import re
# Normalize path separators
normalized = credential.replace("\\", "/")
# Pattern: path ending with {provider}_oauth_{number}.json
match = re.search(r"/([a-z_]+)_oauth_\d+\.json$", normalized, re.IGNORECASE)
if match:
return match.group(1).lower()
# Pattern: oauth_creds/{provider}_...
match = re.search(r"oauth_creds/([a-z_]+)_", normalized, re.IGNORECASE)
if match:
return match.group(1).lower()
# Pattern: filename only {provider}_oauth_{number}.json (no path)
match = re.match(r"([a-z_]+)_oauth_\d+\.json$", normalized, re.IGNORECASE)
if match:
return match.group(1).lower()
return None
def _get_provider_instance(self, provider: str) -> Optional[Any]:
"""
Get or create a provider plugin instance.
Args:
provider: The provider name
Returns:
Provider plugin instance or None
"""
if not provider:
return None
plugin_class = self.provider_plugins.get(provider)
if not plugin_class:
return None
# Get or create provider instance from cache
if provider not in self._provider_instances:
# Instantiate the plugin if it's a class, or use it directly if already an instance
if isinstance(plugin_class, type):
self._provider_instances[provider] = plugin_class()
else:
self._provider_instances[provider] = plugin_class
return self._provider_instances[provider]
def _get_usage_reset_config(self, credential: str) -> Optional[Dict[str, Any]]:
"""
Get the usage reset configuration for a credential from its provider plugin.
Args:
credential: The credential identifier
Returns:
Configuration dict with window_seconds, field_name, etc.
or None to use default daily reset.
"""
provider = self._get_provider_from_credential(credential)
plugin_instance = self._get_provider_instance(provider)
if plugin_instance and hasattr(plugin_instance, "get_usage_reset_config"):
return plugin_instance.get_usage_reset_config(credential)
return None
def _get_reset_mode(self, credential: str) -> str:
"""
Get the reset mode for a credential: 'credential' or 'per_model'.
Args:
credential: The credential identifier
Returns:
"per_model" or "credential" (default)
"""
config = self._get_usage_reset_config(credential)
return config.get("mode", "credential") if config else "credential"
def _get_model_quota_group(self, credential: str, model: str) -> Optional[str]:
"""
Get the quota group for a model, if the provider defines one.
Args:
credential: The credential identifier
model: Model name (with or without provider prefix)
Returns:
Group name (e.g., "claude") or None if not grouped
"""
provider = self._get_provider_from_credential(credential)
plugin_instance = self._get_provider_instance(provider)
if plugin_instance and hasattr(plugin_instance, "get_model_quota_group"):
return plugin_instance.get_model_quota_group(model)
return None
def _get_grouped_models(self, credential: str, group: str) -> List[str]:
"""
Get all model names in a quota group (with provider prefix).
Args:
credential: The credential identifier
group: Group name (e.g., "claude")
Returns:
List of full model names (e.g., ["antigravity/claude-opus-4-5", ...])
"""
provider = self._get_provider_from_credential(credential)
plugin_instance = self._get_provider_instance(provider)
if plugin_instance and hasattr(plugin_instance, "get_models_in_quota_group"):
models = plugin_instance.get_models_in_quota_group(group)
# Add provider prefix
return [f"{provider}/{m}" for m in models]
return []
def _get_model_usage_weight(self, credential: str, model: str) -> int:
"""
Get the usage weight for a model when calculating grouped usage.
Args:
credential: The credential identifier
model: Model name (with or without provider prefix)
Returns:
Weight multiplier (default 1 if not configured)
"""
provider = self._get_provider_from_credential(credential)
plugin_instance = self._get_provider_instance(provider)
if plugin_instance and hasattr(plugin_instance, "get_model_usage_weight"):
return plugin_instance.get_model_usage_weight(model)
return 1
# Providers where request_count should be used for credential selection
# instead of success_count (because failed requests also consume quota)
_REQUEST_COUNT_PROVIDERS = {"antigravity"}
def _get_grouped_usage_count(self, key: str, model: str) -> int:
"""
Get usage count for credential selection, considering quota groups.
For providers in _REQUEST_COUNT_PROVIDERS (e.g., antigravity), uses
request_count instead of success_count since failed requests also
consume quota.
If the model belongs to a quota group, the request_count is already
synced across all models in the group (by record_success/record_failure),
so we just read from the requested model directly.
Args:
key: Credential identifier
model: Model name (with provider prefix, e.g., "antigravity/claude-sonnet-4-5")
Returns:
Usage count for the model (synced across group if applicable)
"""
# Determine usage field based on provider
# Some providers (antigravity) count failed requests against quota
provider = self._get_provider_from_credential(key)
usage_field = (
"request_count"
if provider in self._REQUEST_COUNT_PROVIDERS
else "success_count"
)
# For providers with synced quota groups (antigravity), request_count
# is already synced across all models in the group, so just read directly.
# For other providers, we still need to sum success_count across group.
if provider in self._REQUEST_COUNT_PROVIDERS:
# request_count is synced - just read the model's value
return self._get_usage_count(key, model, usage_field)
# For non-synced providers, check if model is in a quota group and sum
group = self._get_model_quota_group(key, model)
if group:
# Get all models in the group
grouped_models = self._get_grouped_models(key, group)
# Sum weighted usage across all models in the group
total_weighted_usage = 0
for grouped_model in grouped_models:
usage = self._get_usage_count(key, grouped_model, usage_field)
weight = self._get_model_usage_weight(key, grouped_model)
total_weighted_usage += usage * weight
return total_weighted_usage
# Not grouped - return individual model usage (no weight applied)
return self._get_usage_count(key, model, usage_field)
def _get_quota_display(self, key: str, model: str) -> str:
"""
Get a formatted quota display string for logging.
For antigravity (providers in _REQUEST_COUNT_PROVIDERS), returns:
"quota: 170/250 [32%]" format
For other providers, returns:
"usage: 170" format (no max available)
Args:
key: Credential identifier
model: Model name (with provider prefix)
Returns:
Formatted string for logging
"""
provider = self._get_provider_from_credential(key)
if provider not in self._REQUEST_COUNT_PROVIDERS:
# Non-antigravity: just show usage count
usage = self._get_usage_count(key, model, "success_count")
return f"usage: {usage}"
# Antigravity: show quota display with remaining percentage
if self._usage_data is None:
return "quota: 0/? [100%]"
key_data = self._usage_data.get(key, {})
model_data = key_data.get("models", {}).get(model, {})
request_count = model_data.get("request_count", 0)
max_requests = model_data.get("quota_max_requests")
if max_requests:
remaining = max_requests - request_count
remaining_pct = (
int((remaining / max_requests) * 100) if max_requests > 0 else 0
)
return f"quota: {request_count}/{max_requests} [{remaining_pct}%]"
else:
return f"quota: {request_count}"
def _get_usage_field_name(self, credential: str) -> str:
"""
Get the usage tracking field name for a credential.
Returns the provider-specific field name if configured,
otherwise falls back to "daily".
Args:
credential: The credential identifier
Returns:
Field name string (e.g., "5h_window", "weekly", "daily")
"""
config = self._get_usage_reset_config(credential)
if config and "field_name" in config:
return config["field_name"]
# Check provider default
provider = self._get_provider_from_credential(credential)
plugin_instance = self._get_provider_instance(provider)
if plugin_instance and hasattr(plugin_instance, "get_default_usage_field_name"):
return plugin_instance.get_default_usage_field_name()
return "daily"
def _get_usage_count(
self, key: str, model: str, field: str = "success_count"
) -> int:
"""
Get the current usage count for a model from the appropriate usage structure.
Supports both:
- New per-model structure: {"models": {"model_name": {"success_count": N, ...}}}
- Legacy structure: {"daily": {"models": {"model_name": {"success_count": N, ...}}}}
Args:
key: Credential identifier
model: Model name
field: The field to read for usage count (default: "success_count").
Use "request_count" for providers where failed requests also
consume quota (e.g., antigravity).
Returns:
Usage count for the model in the current window/period
"""
if self._usage_data is None:
return 0
key_data = self._usage_data.get(key, {})
reset_mode = self._get_reset_mode(key)
if reset_mode == "per_model":
# New per-model structure: key_data["models"][model][field]
return key_data.get("models", {}).get(model, {}).get(field, 0)
else:
# Legacy structure: key_data["daily"]["models"][model][field]
return (
key_data.get("daily", {}).get("models", {}).get(model, {}).get(field, 0)
)
# =========================================================================
# TIMESTAMP FORMATTING HELPERS
# =========================================================================
def _format_timestamp_local(self, ts: Optional[float]) -> Optional[str]:
"""
Format Unix timestamp as local time string with timezone offset.
Args:
ts: Unix timestamp or None
Returns:
Formatted string like "2025-12-07 14:30:17 +0100" or None
"""
if ts is None:
return None
try:
dt = datetime.fromtimestamp(ts).astimezone() # Local timezone
# Use UTC offset for conciseness (works on all platforms)
return dt.strftime("%Y-%m-%d %H:%M:%S %z")
except (OSError, ValueError, OverflowError):
return None
def _add_readable_timestamps(self, data: Dict) -> Dict:
"""
Add human-readable timestamp fields to usage data before saving.
Adds 'window_started' and 'quota_resets' fields derived from
Unix timestamps for easier debugging and monitoring.
Args:
data: The usage data dict to enhance
Returns:
The same dict with readable timestamp fields added
"""
for key, key_data in data.items():
# Handle per-model structure
models = key_data.get("models", {})
for model_name, model_stats in models.items():
if not isinstance(model_stats, dict):
continue
# Add readable window start time
window_start = model_stats.get("window_start_ts")
if window_start:
model_stats["window_started"] = self._format_timestamp_local(
window_start
)
elif "window_started" in model_stats:
del model_stats["window_started"]
# Add readable reset time
quota_reset = model_stats.get("quota_reset_ts")
if quota_reset:
model_stats["quota_resets"] = self._format_timestamp_local(
quota_reset
)
elif "quota_resets" in model_stats:
del model_stats["quota_resets"]
return data
def _sort_sequential(
self,
candidates: List[Tuple[str, int]],
credential_priorities: Optional[Dict[str, int]] = None,
) -> List[Tuple[str, int]]:
"""
Sort credentials for sequential mode with position retention.
Credentials maintain their position based on established usage patterns,
ensuring that actively-used credentials remain primary until exhausted.
Sorting order (within each sort key, lower value = higher priority):
1. Priority tier (lower number = higher priority)
2. Usage count (higher = more established in rotation, maintains position)
3. Last used timestamp (higher = more recent, tiebreaker for stickiness)
4. Credential ID (alphabetical, stable ordering)
Args:
candidates: List of (credential_id, usage_count) tuples
credential_priorities: Optional dict mapping credentials to priority levels
Returns:
Sorted list of candidates (same format as input)
"""
if not candidates:
return []
if len(candidates) == 1:
return candidates
def sort_key(item: Tuple[str, int]) -> Tuple[int, int, float, str]:
cred, usage_count = item
priority = (
credential_priorities.get(cred, 999) if credential_priorities else 999
)
last_used = (
self._usage_data.get(cred, {}).get("last_used_ts", 0)
if self._usage_data
else 0
)
return (
priority, # ASC: lower priority number = higher priority
-usage_count, # DESC: higher usage = more established
-last_used, # DESC: more recent = preferred for ties
cred, # ASC: stable alphabetical ordering
)
sorted_candidates = sorted(candidates, key=sort_key)
# Debug logging - show top 3 credentials in ordering
if lib_logger.isEnabledFor(logging.DEBUG):
order_info = [
f"{mask_credential(c)}(p={credential_priorities.get(c, 999) if credential_priorities else 'N/A'}, u={u})"
for c, u in sorted_candidates[:3]
]
lib_logger.debug(f"Sequential ordering: {' → '.join(order_info)}")
return sorted_candidates
async def _lazy_init(self):
"""Initializes the usage data by loading it from the file asynchronously."""
async with self._init_lock:
if not self._initialized.is_set():
await self._load_usage()
await self._reset_daily_stats_if_needed()
self._initialized.set()
async def _load_usage(self):
"""Loads usage data from the JSON file asynchronously with resilience."""
async with self._data_lock:
if not os.path.exists(self.file_path):
self._usage_data = {}
return
try:
async with aiofiles.open(self.file_path, "r") as f:
content = await f.read()
self._usage_data = json.loads(content) if content.strip() else {}
except FileNotFoundError:
# File deleted between exists check and open
self._usage_data = {}
except json.JSONDecodeError as e:
lib_logger.warning(
f"Corrupted usage file {self.file_path}: {e}. Starting fresh."
)
self._usage_data = {}
except (OSError, PermissionError, IOError) as e:
lib_logger.warning(
f"Cannot read usage file {self.file_path}: {e}. Using empty state."
)
self._usage_data = {}
async def _save_usage(self):
"""Saves the current usage data using the resilient state writer."""
if self._usage_data is None:
return
async with self._data_lock:
# Add human-readable timestamp fields before saving
self._add_readable_timestamps(self._usage_data)
# Hand off to resilient writer - handles retries and disk failures
self._state_writer.write(self._usage_data)
async def _get_usage_data_snapshot(self) -> Dict[str, Any]:
"""
Get a shallow copy of the current usage data.
Returns:
Copy of usage data dict (safe for reading without lock)
"""
await self._lazy_init()
async with self._data_lock:
return dict(self._usage_data) if self._usage_data else {}
async def get_available_credentials_for_model(
self, credentials: List[str], model: str
) -> List[str]:
"""
Get credentials that are not on cooldown for a specific model.
Filters out credentials where:
- key_cooldown_until > now (key-level cooldown)
- model_cooldowns[model] > now (model-specific cooldown, includes quota exhausted)
Args:
credentials: List of credential identifiers to check
model: Model name to check cooldowns for
Returns:
List of credentials that are available (not on cooldown) for this model
"""
await self._lazy_init()
now = time.time()
available = []
async with self._data_lock:
for key in credentials:
key_data = self._usage_data.get(key, {})
# Skip if key-level cooldown is active
if (key_data.get("key_cooldown_until") or 0) > now:
continue
# Skip if model-specific cooldown is active
if (key_data.get("model_cooldowns", {}).get(model) or 0) > now:
continue
available.append(key)
return available
async def _reset_daily_stats_if_needed(self):
"""
Checks if usage stats need to be reset for any key.
Supports three reset modes:
1. per_model: Each model has its own window, resets based on quota_reset_ts or fallback window
2. credential: One window per credential (legacy with custom window duration)
3. daily: Legacy daily reset at daily_reset_time_utc
"""
if self._usage_data is None:
return
now_utc = datetime.now(timezone.utc)
now_ts = time.time()
today_str = now_utc.date().isoformat()
needs_saving = False
for key, data in self._usage_data.items():
reset_config = self._get_usage_reset_config(key)
if reset_config:
reset_mode = reset_config.get("mode", "credential")
if reset_mode == "per_model":
# Per-model window reset
needs_saving |= await self._check_per_model_resets(
key, data, reset_config, now_ts
)
else:
# Credential-level window reset (legacy)
needs_saving |= await self._check_window_reset(
key, data, reset_config, now_ts
)
elif self.daily_reset_time_utc:
# Legacy daily reset
needs_saving |= await self._check_daily_reset(
key, data, now_utc, today_str, now_ts
)
if needs_saving:
await self._save_usage()
    async def _check_per_model_resets(
        self,
        key: str,
        data: Dict[str, Any],
        reset_config: Dict[str, Any],
        now_ts: float,
    ) -> bool:
        """
        Check and perform per-model resets for a credential.

        Each model resets independently based on:
        1. quota_reset_ts (authoritative, from quota exhausted error) if set
        2. window_start_ts + window_seconds (fallback) otherwise

        Grouped models reset together - all models in a group must be ready
        before any member is archived and cleared.

        Args:
            key: Credential identifier
            data: Usage data for this credential
            reset_config: Provider's reset configuration
            now_ts: Current timestamp

        Returns:
            True if data was modified and needs saving
        """
        window_seconds = reset_config.get("window_seconds", 86400)  # 24h fallback
        models_data = data.get("models", {})
        if not models_data:
            return False
        modified = False
        # Groups already examined this pass - a group is resolved once,
        # regardless of how many of its members appear in models_data.
        processed_groups = set()
        # list(...) snapshot: _reset_model_data mutates entries in place.
        for model, model_data in list(models_data.items()):
            # Check if this model is in a quota group
            group = self._get_model_quota_group(key, model)
            if group:
                if group in processed_groups:
                    continue  # Already handled this group
                # Check if entire group should reset (all members past
                # their window; any active member holds the group back)
                if self._should_group_reset(
                    key, group, models_data, window_seconds, now_ts
                ):
                    # Archive and reset all models in group
                    grouped_models = self._get_grouped_models(key, group)
                    archived_count = 0
                    for grouped_model in grouped_models:
                        if grouped_model in models_data:
                            gm_data = models_data[grouped_model]
                            self._archive_model_to_global(data, grouped_model, gm_data)
                            self._reset_model_data(gm_data)
                            archived_count += 1
                    if archived_count > 0:
                        lib_logger.info(
                            f"Reset model group '{group}' ({archived_count} models) for {mask_credential(key)}"
                        )
                        modified = True
                processed_groups.add(group)
            else:
                # Ungrouped model - check individually
                if self._should_model_reset(model_data, window_seconds, now_ts):
                    self._archive_model_to_global(data, model, model_data)
                    self._reset_model_data(model_data)
                    lib_logger.info(f"Reset model {model} for {mask_credential(key)}")
                    modified = True
        # Preserve unexpired cooldowns (important for long quota cooldowns)
        # and clear consecutive-failure tracking after any reset.
        if modified:
            self._preserve_unexpired_cooldowns(key, data, now_ts)
            if "failures" in data:
                data["failures"] = {}
        return modified
def _should_model_reset(
self, model_data: Dict[str, Any], window_seconds: int, now_ts: float
) -> bool:
"""
Check if a single model should reset.
Returns True if:
- quota_reset_ts is set AND now >= quota_reset_ts, OR
- quota_reset_ts is NOT set AND now >= window_start_ts + window_seconds
"""
quota_reset = model_data.get("quota_reset_ts")
window_start = model_data.get("window_start_ts")
if quota_reset:
return now_ts >= quota_reset
elif window_start:
return now_ts >= window_start + window_seconds
return False
def _should_group_reset(
self,
key: str,
group: str,
models_data: Dict[str, Dict],
window_seconds: int,
now_ts: float,
) -> bool:
"""
Check if all models in a group should reset.
All models in the group must be ready to reset.
If any model has an active cooldown/window, the whole group waits.
"""
grouped_models = self._get_grouped_models(key, group)
# Track if any model in group has data
any_has_data = False
for grouped_model in grouped_models:
model_data = models_data.get(grouped_model, {})
if not model_data or (
model_data.get("window_start_ts") is None
and model_data.get("success_count", 0) == 0
):
continue # No stats for this model yet
any_has_data = True
if not self._should_model_reset(model_data, window_seconds, now_ts):
return False # At least one model not ready
return any_has_data
def _archive_model_to_global(
self, data: Dict[str, Any], model: str, model_data: Dict[str, Any]
) -> None:
"""Archive a single model's stats to global."""
global_data = data.setdefault("global", {"models": {}})
global_model = global_data["models"].setdefault(
model,
{
"success_count": 0,
"prompt_tokens": 0,
"completion_tokens": 0,
"approx_cost": 0.0,
},
)
global_model["success_count"] += model_data.get("success_count", 0)
global_model["prompt_tokens"] += model_data.get("prompt_tokens", 0)
global_model["completion_tokens"] += model_data.get("completion_tokens", 0)
global_model["approx_cost"] += model_data.get("approx_cost", 0.0)
def _reset_model_data(self, model_data: Dict[str, Any]) -> None:
"""Reset a model's window and stats."""
model_data["window_start_ts"] = None
model_data["quota_reset_ts"] = None
model_data["success_count"] = 0
model_data["failure_count"] = 0
model_data["request_count"] = 0
model_data["prompt_tokens"] = 0
model_data["completion_tokens"] = 0
model_data["approx_cost"] = 0.0
# Reset quota baseline fields only if they exist (Antigravity-specific)
# These are added by update_quota_baseline(), only called for Antigravity
if "baseline_remaining_fraction" in model_data:
model_data["baseline_remaining_fraction"] = None
model_data["baseline_fetched_at"] = None
model_data["requests_at_baseline"] = None
# Reset quota display but keep max_requests (it doesn't change between periods)
max_req = model_data.get("quota_max_requests")
if max_req:
model_data["quota_display"] = f"0/{max_req}"
async def _check_window_reset(
self,
key: str,
data: Dict[str, Any],
reset_config: Dict[str, Any],
now_ts: float,
) -> bool:
"""
Check and perform rolling window reset for a credential.
Args:
key: Credential identifier
data: Usage data for this credential
reset_config: Provider's reset configuration
now_ts: Current timestamp
Returns:
True if data was modified and needs saving
"""
window_seconds = reset_config.get("window_seconds", 86400) # Default 24h
field_name = reset_config.get("field_name", "window")
description = reset_config.get("description", "rolling window")
# Get current window data
window_data = data.get(field_name, {})
window_start = window_data.get("start_ts")
# No window started yet - nothing to reset
if window_start is None:
return False
# Check if window has expired
window_end = window_start + window_seconds
if now_ts < window_end:
# Window still active
return False
# Window expired - perform reset
hours_elapsed = (now_ts - window_start) / 3600
lib_logger.info(
f"Resetting {field_name} for {mask_credential(key)} - "
f"{description} expired after {hours_elapsed:.1f}h"
)
# Archive to global
self._archive_to_global(data, window_data)
# Preserve unexpired cooldowns
self._preserve_unexpired_cooldowns(key, data, now_ts)
# Reset window stats (but don't start new window until first request)
data[field_name] = {"start_ts": None, "models": {}}
# Reset consecutive failures
if "failures" in data:
data["failures"] = {}
return True
async def _check_daily_reset(
self,
key: str,
data: Dict[str, Any],
now_utc: datetime,
today_str: str,
now_ts: float,
) -> bool:
"""
Check and perform legacy daily reset for a credential.
Args:
key: Credential identifier
data: Usage data for this credential
now_utc: Current datetime in UTC
today_str: Today's date as ISO string
now_ts: Current timestamp
Returns:
True if data was modified and needs saving
"""
last_reset_str = data.get("last_daily_reset", "")
if last_reset_str == today_str:
return False
last_reset_dt = None
if last_reset_str:
try:
last_reset_dt = datetime.fromisoformat(last_reset_str).replace(
tzinfo=timezone.utc
)
except ValueError:
pass
# Determine the reset threshold for today
reset_threshold_today = datetime.combine(
now_utc.date(), self.daily_reset_time_utc
)
if not (
last_reset_dt is None or last_reset_dt < reset_threshold_today <= now_utc
):
return False
lib_logger.debug(f"Performing daily reset for key {mask_credential(key)}")
# Preserve unexpired cooldowns
self._preserve_unexpired_cooldowns(key, data, now_ts)
# Reset consecutive failures
if "failures" in data:
data["failures"] = {}
# Archive daily stats to global
daily_data = data.get("daily", {})
if daily_data:
self._archive_to_global(data, daily_data)
# Reset daily stats
data["daily"] = {"date": today_str, "models": {}}
data["last_daily_reset"] = today_str
return True
def _archive_to_global(
self, data: Dict[str, Any], source_data: Dict[str, Any]
) -> None:
"""
Archive usage stats from a source field (daily/window) to global.
Args:
data: The credential's usage data
source_data: The source field data to archive (has "models" key)
"""
global_data = data.setdefault("global", {"models": {}})
for model, stats in source_data.get("models", {}).items():
global_model_stats = global_data["models"].setdefault(
model,
{
"success_count": 0,
"prompt_tokens": 0,
"completion_tokens": 0,
"approx_cost": 0.0,
},
)
global_model_stats["success_count"] += stats.get("success_count", 0)
global_model_stats["prompt_tokens"] += stats.get("prompt_tokens", 0)
global_model_stats["completion_tokens"] += stats.get("completion_tokens", 0)
global_model_stats["approx_cost"] += stats.get("approx_cost", 0.0)
def _preserve_unexpired_cooldowns(
self, key: str, data: Dict[str, Any], now_ts: float
) -> None:
"""
Preserve unexpired cooldowns during reset (important for long quota cooldowns).
Args:
key: Credential identifier (for logging)
data: The credential's usage data
now_ts: Current timestamp
"""
# Preserve unexpired model cooldowns
if "model_cooldowns" in data:
active_cooldowns = {
model: end_time
for model, end_time in data["model_cooldowns"].items()
if end_time > now_ts
}
if active_cooldowns:
max_remaining = max(
end_time - now_ts for end_time in active_cooldowns.values()
)
hours_remaining = max_remaining / 3600
lib_logger.info(
f"Preserving {len(active_cooldowns)} active cooldown(s) "
f"for key {mask_credential(key)} during reset "
f"(longest: {hours_remaining:.1f}h remaining)"
)
data["model_cooldowns"] = active_cooldowns
else:
data["model_cooldowns"] = {}
# Preserve unexpired key-level cooldown
if data.get("key_cooldown_until"):
if data["key_cooldown_until"] <= now_ts:
data["key_cooldown_until"] = None
else:
hours_remaining = (data["key_cooldown_until"] - now_ts) / 3600
lib_logger.info(
f"Preserving key-level cooldown for {mask_credential(key)} "
f"during reset ({hours_remaining:.1f}h remaining)"
)
else:
data["key_cooldown_until"] = None
def _initialize_key_states(self, keys: List[str]):
"""Initializes state tracking for all provided keys if not already present."""
for key in keys:
if key not in self.key_states:
self.key_states[key] = {
"lock": asyncio.Lock(),
"condition": asyncio.Condition(),
"models_in_use": {}, # Dict[model_name, concurrent_count]
}
def _select_weighted_random(self, candidates: List[tuple], tolerance: float) -> str:
"""
Selects a credential using weighted random selection based on usage counts.
Args:
candidates: List of (credential_id, usage_count) tuples
tolerance: Tolerance value for weight calculation
Returns:
Selected credential ID
Formula:
weight = (max_usage - credential_usage) + tolerance + 1
This formula ensures:
- Lower usage = higher weight = higher selection probability
- Tolerance adds variability: higher tolerance means more randomness
- The +1 ensures all credentials have at least some chance of selection
"""
if not candidates:
raise ValueError("Cannot select from empty candidate list")
if len(candidates) == 1:
return candidates[0][0]
# Extract usage counts
usage_counts = [usage for _, usage in candidates]
max_usage = max(usage_counts)
# Calculate weights using the formula: (max - current) + tolerance + 1
weights = []
for credential, usage in candidates:
weight = (max_usage - usage) + tolerance + 1
weights.append(weight)
# Log weight distribution for debugging
if lib_logger.isEnabledFor(logging.DEBUG):
total_weight = sum(weights)
weight_info = ", ".join(
f"{mask_credential(cred)}: w={w:.1f} ({w / total_weight * 100:.1f}%)"
for (cred, _), w in zip(candidates, weights)
)
# lib_logger.debug(f"Weighted selection candidates: {weight_info}")
# Random selection with weights
selected_credential = random.choices(
[cred for cred, _ in candidates], weights=weights, k=1
)[0]
return selected_credential
async def acquire_key(
    self,
    available_keys: List[str],
    model: str,
    deadline: float,
    max_concurrent: int = 1,
    credential_priorities: Optional[Dict[str, int]] = None,
    credential_tier_names: Optional[Dict[str, str]] = None,
) -> str:
    """
    Acquires the best available key using a tiered, model-aware locking strategy,
    respecting a global deadline and credential priorities.

    Priority Logic:
    - Groups credentials by priority level (1=highest, 2=lower, etc.)
    - Always tries highest priority (lowest number) first
    - Within same priority, sorts by usage count (load balancing)
    - Only moves to next priority if all higher-priority keys exhausted/busy

    Tier Logic (within a candidate set):
    - Tier 1: credentials with no in-flight requests at all (preferred)
    - Tier 2: credentials already serving requests but below the effective
      concurrency cap for this model

    Args:
        available_keys: List of credential identifiers to choose from
        model: Model name being requested
        deadline: Timestamp after which to stop trying
        max_concurrent: Maximum concurrent requests allowed per credential
        credential_priorities: Optional dict mapping credentials to priority levels (1=highest)
        credential_tier_names: Optional dict mapping credentials to tier names (for logging)

    Returns:
        Selected credential identifier

    Raises:
        NoAvailableKeysError: If no key could be acquired within the deadline
    """
    await self._lazy_init()
    await self._reset_daily_stats_if_needed()
    self._initialize_key_states(available_keys)
    # This loop continues as long as the global deadline has not been met.
    while time.time() < deadline:
        now = time.time()
        # Group credentials by priority level (if priorities provided)
        if credential_priorities:
            # Group keys by priority level
            priority_groups = {}
            async with self._data_lock:
                for key in available_keys:
                    key_data = self._usage_data.get(key, {})
                    # Skip keys on cooldown (key-level OR for this model);
                    # `or 0` tolerates missing/None cooldown values.
                    if (key_data.get("key_cooldown_until") or 0) > now or (
                        key_data.get("model_cooldowns", {}).get(model) or 0
                    ) > now:
                        continue
                    # Get priority for this key (default to 999 if not specified)
                    priority = credential_priorities.get(key, 999)
                    # Get usage count for load balancing within priority groups
                    # Uses grouped usage if model is in a quota group
                    usage_count = self._get_grouped_usage_count(key, model)
                    # Group by priority
                    if priority not in priority_groups:
                        priority_groups[priority] = []
                    priority_groups[priority].append((key, usage_count))
            # Try priority groups in order (1, 2, 3, ...)
            sorted_priorities = sorted(priority_groups.keys())
            for priority_level in sorted_priorities:
                keys_in_priority = priority_groups[priority_level]
                # Determine selection method based on provider's rotation mode
                # NOTE(review): provider/rotation_mode are loop-invariant and
                # could be hoisted above this loop; recomputed per group here.
                provider = model.split("/")[0] if "/" in model else ""
                rotation_mode = self._get_rotation_mode(provider)
                # Calculate effective concurrency based on priority tier
                multiplier = self._get_priority_multiplier(
                    provider, priority_level, rotation_mode
                )
                effective_max_concurrent = max_concurrent * multiplier
                # Within each priority group, use existing tier1/tier2 logic
                tier1_keys, tier2_keys = [], []
                for key, usage_count in keys_in_priority:
                    key_state = self.key_states[key]
                    # Tier 1: Completely idle keys (preferred)
                    if not key_state["models_in_use"]:
                        tier1_keys.append((key, usage_count))
                    # Tier 2: Keys that can accept more concurrent requests
                    elif (
                        key_state["models_in_use"].get(model, 0)
                        < effective_max_concurrent
                    ):
                        tier2_keys.append((key, usage_count))
                if rotation_mode == "sequential":
                    # Sequential mode: sort credentials by priority, usage, recency
                    # Keep all candidates in sorted order (no filtering to single key)
                    selection_method = "sequential"
                    if tier1_keys:
                        tier1_keys = self._sort_sequential(
                            tier1_keys, credential_priorities
                        )
                    if tier2_keys:
                        tier2_keys = self._sort_sequential(
                            tier2_keys, credential_priorities
                        )
                elif self.rotation_tolerance > 0:
                    # Balanced mode with weighted randomness: pick one key per
                    # tier up front, then narrow the tier list to just it.
                    selection_method = "weighted-random"
                    if tier1_keys:
                        selected_key = self._select_weighted_random(
                            tier1_keys, self.rotation_tolerance
                        )
                        tier1_keys = [
                            (k, u) for k, u in tier1_keys if k == selected_key
                        ]
                    if tier2_keys:
                        selected_key = self._select_weighted_random(
                            tier2_keys, self.rotation_tolerance
                        )
                        tier2_keys = [
                            (k, u) for k, u in tier2_keys if k == selected_key
                        ]
                else:
                    # Deterministic: sort by usage within each tier
                    selection_method = "least-used"
                    tier1_keys.sort(key=lambda x: x[1])
                    tier2_keys.sort(key=lambda x: x[1])
                # Try to acquire from Tier 1 first. Re-check idleness under
                # the key's lock, since state may change while unlocked.
                for key, usage in tier1_keys:
                    state = self.key_states[key]
                    async with state["lock"]:
                        if not state["models_in_use"]:
                            state["models_in_use"][model] = 1
                            tier_name = (
                                credential_tier_names.get(key, "unknown")
                                if credential_tier_names
                                else "unknown"
                            )
                            quota_display = self._get_quota_display(key, model)
                            lib_logger.info(
                                f"Acquired key {mask_credential(key)} for model {model} "
                                f"(tier: {tier_name}, priority: {priority_level}, selection: {selection_method}, {quota_display})"
                            )
                            return key
                # Then try Tier 2 (shares an already-busy credential)
                for key, usage in tier2_keys:
                    state = self.key_states[key]
                    async with state["lock"]:
                        current_count = state["models_in_use"].get(model, 0)
                        if current_count < effective_max_concurrent:
                            state["models_in_use"][model] = current_count + 1
                            tier_name = (
                                credential_tier_names.get(key, "unknown")
                                if credential_tier_names
                                else "unknown"
                            )
                            quota_display = self._get_quota_display(key, model)
                            lib_logger.info(
                                f"Acquired key {mask_credential(key)} for model {model} "
                                f"(tier: {tier_name}, priority: {priority_level}, selection: {selection_method}, concurrent: {state['models_in_use'][model]}/{effective_max_concurrent}, {quota_display})"
                            )
                            return key
            # If we get here, all priority groups were exhausted but keys might become available
            # Collect all keys across all priorities for waiting
            all_potential_keys = []
            for keys_list in priority_groups.values():
                all_potential_keys.extend(keys_list)
            if not all_potential_keys:
                lib_logger.warning(
                    "No keys are eligible (all on cooldown or filtered out). Waiting before re-evaluating."
                )
                await asyncio.sleep(1)
                continue
            # Wait for the highest priority key with lowest usage
            best_priority = min(priority_groups.keys())
            best_priority_keys = priority_groups[best_priority]
            best_wait_key = min(best_priority_keys, key=lambda x: x[1])[0]
            wait_condition = self.key_states[best_wait_key]["condition"]
            lib_logger.info(
                f"All Priority-{best_priority} keys are busy. Waiting for highest priority credential to become available..."
            )
        else:
            # Original logic when no priorities specified
            # Determine selection method based on provider's rotation mode
            provider = model.split("/")[0] if "/" in model else ""
            rotation_mode = self._get_rotation_mode(provider)
            # Calculate effective concurrency for default priority (999)
            # When no priorities are specified, all credentials get default priority
            default_priority = 999
            multiplier = self._get_priority_multiplier(
                provider, default_priority, rotation_mode
            )
            effective_max_concurrent = max_concurrent * multiplier
            tier1_keys, tier2_keys = [], []
            # First, filter the list of available keys to exclude any on cooldown.
            async with self._data_lock:
                for key in available_keys:
                    key_data = self._usage_data.get(key, {})
                    if (key_data.get("key_cooldown_until") or 0) > now or (
                        key_data.get("model_cooldowns", {}).get(model) or 0
                    ) > now:
                        continue
                    # Prioritize keys based on their current usage to ensure load balancing.
                    # Uses grouped usage if model is in a quota group
                    usage_count = self._get_grouped_usage_count(key, model)
                    key_state = self.key_states[key]
                    # Tier 1: Completely idle keys (preferred).
                    if not key_state["models_in_use"]:
                        tier1_keys.append((key, usage_count))
                    # Tier 2: Keys that can accept more concurrent requests for this model.
                    elif (
                        key_state["models_in_use"].get(model, 0)
                        < effective_max_concurrent
                    ):
                        tier2_keys.append((key, usage_count))
            if rotation_mode == "sequential":
                # Sequential mode: sort credentials by priority, usage, recency
                # Keep all candidates in sorted order (no filtering to single key)
                selection_method = "sequential"
                if tier1_keys:
                    tier1_keys = self._sort_sequential(
                        tier1_keys, credential_priorities
                    )
                if tier2_keys:
                    tier2_keys = self._sort_sequential(
                        tier2_keys, credential_priorities
                    )
            elif self.rotation_tolerance > 0:
                # Balanced mode with weighted randomness
                selection_method = "weighted-random"
                if tier1_keys:
                    selected_key = self._select_weighted_random(
                        tier1_keys, self.rotation_tolerance
                    )
                    tier1_keys = [
                        (k, u) for k, u in tier1_keys if k == selected_key
                    ]
                if tier2_keys:
                    selected_key = self._select_weighted_random(
                        tier2_keys, self.rotation_tolerance
                    )
                    tier2_keys = [
                        (k, u) for k, u in tier2_keys if k == selected_key
                    ]
            else:
                # Deterministic: sort by usage within each tier
                selection_method = "least-used"
                tier1_keys.sort(key=lambda x: x[1])
                tier2_keys.sort(key=lambda x: x[1])
            # Attempt to acquire a key from Tier 1 first.
            for key, usage in tier1_keys:
                state = self.key_states[key]
                async with state["lock"]:
                    if not state["models_in_use"]:
                        state["models_in_use"][model] = 1
                        tier_name = (
                            credential_tier_names.get(key)
                            if credential_tier_names
                            else None
                        )
                        tier_info = f"tier: {tier_name}, " if tier_name else ""
                        quota_display = self._get_quota_display(key, model)
                        lib_logger.info(
                            f"Acquired key {mask_credential(key)} for model {model} "
                            f"({tier_info}selection: {selection_method}, {quota_display})"
                        )
                        return key
            # If no Tier 1 keys are available, try Tier 2.
            for key, usage in tier2_keys:
                state = self.key_states[key]
                async with state["lock"]:
                    current_count = state["models_in_use"].get(model, 0)
                    if current_count < effective_max_concurrent:
                        state["models_in_use"][model] = current_count + 1
                        tier_name = (
                            credential_tier_names.get(key)
                            if credential_tier_names
                            else None
                        )
                        tier_info = f"tier: {tier_name}, " if tier_name else ""
                        quota_display = self._get_quota_display(key, model)
                        lib_logger.info(
                            f"Acquired key {mask_credential(key)} for model {model} "
                            f"({tier_info}selection: {selection_method}, concurrent: {state['models_in_use'][model]}/{effective_max_concurrent}, {quota_display})"
                        )
                        return key
            # If all eligible keys are locked, wait for a key to be released.
            lib_logger.info(
                "All eligible keys are currently locked for this model. Waiting..."
            )
            all_potential_keys = tier1_keys + tier2_keys
            if not all_potential_keys:
                lib_logger.warning(
                    "No keys are eligible (all on cooldown). Waiting before re-evaluating."
                )
                await asyncio.sleep(1)
                continue
            # Wait on the condition of the key with the lowest current usage.
            best_wait_key = min(all_potential_keys, key=lambda x: x[1])[0]
            wait_condition = self.key_states[best_wait_key]["condition"]
        # Shared wait path: both branches above fall through with a
        # wait_condition chosen for the most promising busy key.
        try:
            async with wait_condition:
                remaining_budget = deadline - time.time()
                if remaining_budget <= 0:
                    break  # Exit if the budget has already been exceeded.
                # Wait for a notification, but no longer than the remaining budget or 1 second.
                await asyncio.wait_for(
                    wait_condition.wait(), timeout=min(1, remaining_budget)
                )
            lib_logger.info("Notified that a key was released. Re-evaluating...")
        except asyncio.TimeoutError:
            # This is not an error, just a timeout for the wait. The main loop will re-evaluate.
            lib_logger.info("Wait timed out. Re-evaluating for any available key.")
    # If the loop exits, it means the deadline was exceeded.
    raise NoAvailableKeysError(
        f"Could not acquire a key for model {model} within the global time budget."
    )
async def release_key(self, key: str, model: str):
    """Release one concurrent slot a key holds for *model* and wake waiters."""
    state = self.key_states.get(key)
    if state is None:
        # Unknown credential - nothing was ever acquired.
        return
    async with state["lock"]:
        in_use = state["models_in_use"]
        if model not in in_use:
            lib_logger.warning(
                f"Attempted to release credential {mask_credential(key)} for model {model}, but it was not in use."
            )
        else:
            in_use[model] -= 1
            remaining = in_use[model]
            if remaining <= 0:
                # Drop the entry entirely once the last slot is released.
                del in_use[model]
            lib_logger.info(
                f"Released credential {mask_credential(key)} from model {model} "
                f"(remaining concurrent: {max(0, remaining)})"
            )
    # Wake any acquirers blocked on this credential's condition.
    async with state["condition"]:
        state["condition"].notify_all()
async def record_success(
    self,
    key: str,
    model: str,
    completion_response: Optional[litellm.ModelResponse] = None,
):
    """
    Records a successful API call, resetting failure counters.
    It safely handles cases where token usage data is not available.

    Supports two modes based on provider configuration:
    - per_model: Each model has its own window_start_ts and stats in key_data["models"]
    - credential: Legacy mode with key_data["daily"]["models"]

    Side effects: mutates this credential's usage record, clears its transient
    cooldown for *model*, and persists the usage file via _save_usage().
    """
    await self._lazy_init()
    async with self._data_lock:
        now_ts = time.time()
        today_utc_str = datetime.now(timezone.utc).date().isoformat()
        reset_config = self._get_usage_reset_config(key)
        reset_mode = (
            reset_config.get("mode", "credential") if reset_config else "credential"
        )
        if reset_mode == "per_model":
            # New per-model structure
            key_data = self._usage_data.setdefault(
                key,
                {
                    "models": {},
                    "global": {"models": {}},
                    "model_cooldowns": {},
                    "failures": {},
                },
            )
            # Ensure models dict exists (older records may lack it)
            if "models" not in key_data:
                key_data["models"] = {}
            # Get or create per-model data with window tracking
            model_data = key_data["models"].setdefault(
                model,
                {
                    "window_start_ts": None,
                    "quota_reset_ts": None,
                    "success_count": 0,
                    "failure_count": 0,
                    "request_count": 0,
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "approx_cost": 0.0,
                },
            )
            # Start window lazily on first request for this model
            if model_data.get("window_start_ts") is None:
                model_data["window_start_ts"] = now_ts
                # Set expected quota reset time from provider config
                window_seconds = (
                    reset_config.get("window_seconds", 0) if reset_config else 0
                )
                if window_seconds > 0:
                    model_data["quota_reset_ts"] = now_ts + window_seconds
                window_hours = window_seconds / 3600 if window_seconds else 0
                lib_logger.info(
                    f"Started {window_hours:.1f}h window for model {model} on {mask_credential(key)}"
                )
            # Record stats
            model_data["success_count"] += 1
            model_data["request_count"] = model_data.get("request_count", 0) + 1
            # Sync request_count across quota group (for providers with shared quota pools)
            new_request_count = model_data["request_count"]
            group = self._get_model_quota_group(key, model)
            if group:
                grouped_models = self._get_grouped_models(key, group)
                for grouped_model in grouped_models:
                    if grouped_model != model:
                        other_model_data = key_data["models"].setdefault(
                            grouped_model,
                            {
                                "window_start_ts": None,
                                "quota_reset_ts": None,
                                "success_count": 0,
                                "failure_count": 0,
                                "request_count": 0,
                                "prompt_tokens": 0,
                                "completion_tokens": 0,
                                "approx_cost": 0.0,
                            },
                        )
                        other_model_data["request_count"] = new_request_count
                        # Also sync quota_max_requests if set
                        max_req = model_data.get("quota_max_requests")
                        if max_req:
                            other_model_data["quota_max_requests"] = max_req
                            other_model_data["quota_display"] = (
                                f"{new_request_count}/{max_req}"
                            )
            # Update quota_display if max_requests is set (Antigravity-specific)
            max_req = model_data.get("quota_max_requests")
            if max_req:
                model_data["quota_display"] = (
                    f"{model_data['request_count']}/{max_req}"
                )
            usage_data_ref = model_data  # For token/cost recording below
        else:
            # Legacy credential-level structure
            key_data = self._usage_data.setdefault(
                key,
                {
                    "daily": {"date": today_utc_str, "models": {}},
                    "global": {"models": {}},
                    "model_cooldowns": {},
                    "failures": {},
                },
            )
            if "last_daily_reset" not in key_data:
                key_data["last_daily_reset"] = today_utc_str
            # Get or create model data in daily structure
            usage_data_ref = key_data["daily"]["models"].setdefault(
                model,
                {
                    "success_count": 0,
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "approx_cost": 0.0,
                },
            )
            usage_data_ref["success_count"] += 1
        # Reset consecutive failures for this model - a success ends the streak
        model_failures = key_data.setdefault("failures", {}).setdefault(model, {})
        model_failures["consecutive_failures"] = 0
        # Clear transient cooldown on success (but NOT quota_reset_ts)
        if model in key_data.get("model_cooldowns", {}):
            del key_data["model_cooldowns"][model]
        # Record token and cost usage (usage_data_ref points at whichever
        # structure - per-model or daily - was selected above)
        if (
            completion_response
            and hasattr(completion_response, "usage")
            and completion_response.usage
        ):
            usage = completion_response.usage
            usage_data_ref["prompt_tokens"] += usage.prompt_tokens
            usage_data_ref["completion_tokens"] += getattr(
                usage, "completion_tokens", 0
            )
            lib_logger.info(
                f"Recorded usage from response object for key {mask_credential(key)}"
            )
            try:
                provider_name = model.split("/")[0]
                provider_instance = self._get_provider_instance(provider_name)
                if provider_instance and getattr(
                    provider_instance, "skip_cost_calculation", False
                ):
                    lib_logger.debug(
                        f"Skipping cost calculation for provider '{provider_name}' (custom provider)."
                    )
                else:
                    if isinstance(completion_response, litellm.EmbeddingResponse):
                        # Embeddings have no completion tokens; price inputs only
                        model_info = litellm.get_model_info(model)
                        input_cost = model_info.get("input_cost_per_token")
                        if input_cost:
                            cost = (
                                completion_response.usage.prompt_tokens * input_cost
                            )
                        else:
                            cost = None
                    else:
                        cost = litellm.completion_cost(
                            completion_response=completion_response, model=model
                        )
                    if cost is not None:
                        usage_data_ref["approx_cost"] += cost
            except Exception as e:
                # Cost is best-effort; never fail the success path over pricing
                lib_logger.warning(
                    f"Could not calculate cost for model {model}: {e}"
                )
        elif isinstance(completion_response, asyncio.Future) or hasattr(
            completion_response, "__aiter__"
        ):
            pass  # Stream - usage recorded from chunks
        else:
            lib_logger.warning(
                f"No usage data found in completion response for model {model}. Recording success without token count."
            )
        key_data["last_used_ts"] = now_ts
        await self._save_usage()
async def record_failure(
    self,
    key: str,
    model: str,
    classified_error: ClassifiedError,
    increment_consecutive_failures: bool = True,
):
    """Records a failure and applies cooldowns based on error type.

    Distinguishes between:
    - quota_exceeded: Long cooldown with exact reset time (from quota_reset_timestamp)
      Sets quota_reset_ts on model (and group) - this becomes authoritative stats reset time
    - rate_limit: Short transient cooldown (just wait and retry)
      Only sets model_cooldowns - does NOT affect stats reset timing

    Args:
        key: The API key or credential identifier
        model: The model name
        classified_error: The classified error object
        increment_consecutive_failures: Whether to increment the failure counter.
            Set to False for provider-level errors that shouldn't count against the key.
    """
    await self._lazy_init()
    async with self._data_lock:
        now_ts = time.time()
        today_utc_str = datetime.now(timezone.utc).date().isoformat()
        reset_config = self._get_usage_reset_config(key)
        reset_mode = (
            reset_config.get("mode", "credential") if reset_config else "credential"
        )
        # Initialize key data with appropriate structure for the reset mode
        if reset_mode == "per_model":
            key_data = self._usage_data.setdefault(
                key,
                {
                    "models": {},
                    "global": {"models": {}},
                    "model_cooldowns": {},
                    "failures": {},
                },
            )
        else:
            key_data = self._usage_data.setdefault(
                key,
                {
                    "daily": {"date": today_utc_str, "models": {}},
                    "global": {"models": {}},
                    "model_cooldowns": {},
                    "failures": {},
                },
            )
        # Provider-level errors (transient issues) should not count against the key
        provider_level_errors = {"server_error", "api_connection"}
        # Determine if we should increment the failure counter
        should_increment = (
            increment_consecutive_failures
            and classified_error.error_type not in provider_level_errors
        )
        # Calculate cooldown duration based on error type
        cooldown_seconds = None
        model_cooldowns = key_data.setdefault("model_cooldowns", {})
        if classified_error.error_type == "quota_exceeded":
            # Quota exhausted - use authoritative reset timestamp if available
            quota_reset_ts = classified_error.quota_reset_timestamp
            cooldown_seconds = classified_error.retry_after or 60
            if quota_reset_ts and reset_mode == "per_model":
                # Set quota_reset_ts on model - this becomes authoritative stats reset time
                models_data = key_data.setdefault("models", {})
                model_data = models_data.setdefault(
                    model,
                    {
                        "window_start_ts": None,
                        "quota_reset_ts": None,
                        "success_count": 0,
                        "failure_count": 0,
                        "request_count": 0,
                        "prompt_tokens": 0,
                        "completion_tokens": 0,
                        "approx_cost": 0.0,
                    },
                )
                model_data["quota_reset_ts"] = quota_reset_ts
                # Track failure for quota estimation (request still consumes quota)
                model_data["failure_count"] = model_data.get("failure_count", 0) + 1
                model_data["request_count"] = model_data.get("request_count", 0) + 1
                new_request_count = model_data["request_count"]
                # Apply to all models in the same quota group (including
                # this one - note no grouped_model != model check here)
                group = self._get_model_quota_group(key, model)
                if group:
                    grouped_models = self._get_grouped_models(key, group)
                    for grouped_model in grouped_models:
                        group_model_data = models_data.setdefault(
                            grouped_model,
                            {
                                "window_start_ts": None,
                                "quota_reset_ts": None,
                                "success_count": 0,
                                "failure_count": 0,
                                "request_count": 0,
                                "prompt_tokens": 0,
                                "completion_tokens": 0,
                                "approx_cost": 0.0,
                            },
                        )
                        group_model_data["quota_reset_ts"] = quota_reset_ts
                        # Sync request_count across quota group
                        group_model_data["request_count"] = new_request_count
                        # Also sync quota_max_requests if set
                        max_req = model_data.get("quota_max_requests")
                        if max_req:
                            group_model_data["quota_max_requests"] = max_req
                            group_model_data["quota_display"] = (
                                f"{new_request_count}/{max_req}"
                            )
                        # Also set transient cooldown for selection logic
                        model_cooldowns[grouped_model] = quota_reset_ts
                    reset_dt = datetime.fromtimestamp(
                        quota_reset_ts, tz=timezone.utc
                    )
                    lib_logger.info(
                        f"Quota exhausted for group '{group}' ({len(grouped_models)} models) "
                        f"on {mask_credential(key)}. Resets at {reset_dt.isoformat()}"
                    )
                else:
                    reset_dt = datetime.fromtimestamp(
                        quota_reset_ts, tz=timezone.utc
                    )
                    hours = (quota_reset_ts - now_ts) / 3600
                    lib_logger.info(
                        f"Quota exhausted for model {model} on {mask_credential(key)}. "
                        f"Resets at {reset_dt.isoformat()} ({hours:.1f}h)"
                    )
                    # Set transient cooldown for selection logic
                    model_cooldowns[model] = quota_reset_ts
            else:
                # No authoritative timestamp or legacy mode - just use retry_after
                model_cooldowns[model] = now_ts + cooldown_seconds
                hours = cooldown_seconds / 3600
                lib_logger.info(
                    f"Quota exhausted on {mask_credential(key)} for model {model}. "
                    f"Cooldown: {cooldown_seconds}s ({hours:.1f}h)"
                )
        elif classified_error.error_type == "rate_limit":
            # Transient rate limit - just set short cooldown (does NOT set quota_reset_ts)
            cooldown_seconds = classified_error.retry_after or 60
            model_cooldowns[model] = now_ts + cooldown_seconds
            lib_logger.info(
                f"Rate limit on {mask_credential(key)} for model {model}. "
                f"Transient cooldown: {cooldown_seconds}s"
            )
        elif classified_error.error_type == "authentication":
            # Apply a 5-minute key-level lockout for auth errors
            key_data["key_cooldown_until"] = now_ts + 300
            cooldown_seconds = 300
            model_cooldowns[model] = now_ts + cooldown_seconds
            lib_logger.warning(
                f"Authentication error on key {mask_credential(key)}. Applying 5-minute key-level lockout."
            )
        # If we should increment failures, calculate escalating backoff
        if should_increment:
            failures_data = key_data.setdefault("failures", {})
            model_failures = failures_data.setdefault(
                model, {"consecutive_failures": 0}
            )
            model_failures["consecutive_failures"] += 1
            count = model_failures["consecutive_failures"]
            # If cooldown wasn't set by specific error type, use escalating backoff
            if cooldown_seconds is None:
                backoff_tiers = {1: 10, 2: 30, 3: 60, 4: 120}
                cooldown_seconds = backoff_tiers.get(count, 7200)
                model_cooldowns[model] = now_ts + cooldown_seconds
            lib_logger.warning(
                f"Failure #{count} for key {mask_credential(key)} with model {model}. "
                f"Error type: {classified_error.error_type}, cooldown: {cooldown_seconds}s"
            )
        else:
            # Provider-level errors: apply short cooldown but don't count against key
            if cooldown_seconds is None:
                cooldown_seconds = 30
                model_cooldowns[model] = now_ts + cooldown_seconds
            lib_logger.info(
                f"Provider-level error ({classified_error.error_type}) for key {mask_credential(key)} "
                f"with model {model}. NOT incrementing failures. Cooldown: {cooldown_seconds}s"
            )
        # Check for key-level lockout condition (currently a no-op - see method)
        await self._check_key_lockout(key, key_data)
        # Track failure count for quota estimation (all failures consume quota)
        # This is separate from consecutive_failures which is for backoff logic
        if reset_mode == "per_model":
            models_data = key_data.setdefault("models", {})
            model_data = models_data.setdefault(
                model,
                {
                    "window_start_ts": None,
                    "quota_reset_ts": None,
                    "success_count": 0,
                    "failure_count": 0,
                    "request_count": 0,
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "approx_cost": 0.0,
                },
            )
            # Only increment if not already incremented in quota_exceeded branch
            if classified_error.error_type != "quota_exceeded":
                model_data["failure_count"] = model_data.get("failure_count", 0) + 1
                model_data["request_count"] = model_data.get("request_count", 0) + 1
                # Sync request_count across quota group
                new_request_count = model_data["request_count"]
                group = self._get_model_quota_group(key, model)
                if group:
                    grouped_models = self._get_grouped_models(key, group)
                    for grouped_model in grouped_models:
                        if grouped_model != model:
                            other_model_data = models_data.setdefault(
                                grouped_model,
                                {
                                    "window_start_ts": None,
                                    "quota_reset_ts": None,
                                    "success_count": 0,
                                    "failure_count": 0,
                                    "request_count": 0,
                                    "prompt_tokens": 0,
                                    "completion_tokens": 0,
                                    "approx_cost": 0.0,
                                },
                            )
                            other_model_data["request_count"] = new_request_count
                            # Also sync quota_max_requests if set
                            max_req = model_data.get("quota_max_requests")
                            if max_req:
                                other_model_data["quota_max_requests"] = max_req
                                other_model_data["quota_display"] = (
                                    f"{new_request_count}/{max_req}"
                                )
        # Keep the most recent failure for diagnostics / the quota viewer
        key_data["last_failure"] = {
            "timestamp": now_ts,
            "model": model,
            "error": str(classified_error.original_exception),
        }
        await self._save_usage()
async def update_quota_baseline(
    self,
    credential: str,
    model: str,
    remaining_fraction: float,
    max_requests: Optional[int] = None,
) -> None:
    """
    Update quota baseline data for a credential/model after fetching from API.

    This stores the current quota state as a baseline, which is used to
    estimate remaining quota based on subsequent request counts.

    Args:
        credential: Credential identifier (file path or env:// URI)
        model: Model name (with or without provider prefix)
        remaining_fraction: Current remaining quota as fraction (0.0 to 1.0)
        max_requests: Maximum requests allowed per quota period (e.g., 250 for Claude)
    """
    await self._lazy_init()
    async with self._data_lock:
        now_ts = time.time()
        # Get or create key data structure
        key_data = self._usage_data.setdefault(
            credential,
            {
                "models": {},
                "global": {"models": {}},
                "model_cooldowns": {},
                "failures": {},
            },
        )
        # Ensure models dict exists (older records may lack it)
        if "models" not in key_data:
            key_data["models"] = {}
        # Get or create per-model data (with baseline fields, unlike the
        # defaults used elsewhere in this class)
        model_data = key_data["models"].setdefault(
            model,
            {
                "window_start_ts": None,
                "quota_reset_ts": None,
                "success_count": 0,
                "failure_count": 0,
                "request_count": 0,
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "approx_cost": 0.0,
                "baseline_remaining_fraction": None,
                "baseline_fetched_at": None,
                "requests_at_baseline": None,
            },
        )
        # Calculate actual used requests from API's remaining fraction
        # The API is authoritative - sync our local count to match reality
        if max_requests is not None:
            used_requests = int((1.0 - remaining_fraction) * max_requests)
        else:
            # Estimate max_requests from provider's quota cost
            # This matches how get_max_requests_for_model() calculates it
            provider = self._get_provider_from_credential(credential)
            plugin_instance = self._get_provider_instance(provider)
            if plugin_instance and hasattr(
                plugin_instance, "get_max_requests_for_model"
            ):
                # Get tier from provider's cache
                tier = getattr(plugin_instance, "project_tier_cache", {}).get(
                    credential, "standard-tier"
                )
                # Strip provider prefix from model if present
                clean_model = model.split("/")[-1] if "/" in model else model
                max_requests = plugin_instance.get_max_requests_for_model(
                    clean_model, tier
                )
                used_requests = int((1.0 - remaining_fraction) * max_requests)
            else:
                # Fallback: keep existing count if we can't calculate
                used_requests = model_data.get("request_count", 0)
                max_requests = model_data.get("quota_max_requests")
        # Sync local request count to API's authoritative value
        model_data["request_count"] = used_requests
        model_data["requests_at_baseline"] = used_requests
        # Update baseline fields
        model_data["baseline_remaining_fraction"] = remaining_fraction
        model_data["baseline_fetched_at"] = now_ts
        # Update max_requests and quota_display (max_requests may still be
        # None here if the fallback path above could not determine it)
        if max_requests is not None:
            model_data["quota_max_requests"] = max_requests
            model_data["quota_display"] = f"{used_requests}/{max_requests}"
        # Sync request_count and quota_max_requests across quota group
        group = self._get_model_quota_group(credential, model)
        if group:
            grouped_models = self._get_grouped_models(credential, group)
            for grouped_model in grouped_models:
                if grouped_model != model:
                    other_model_data = key_data["models"].setdefault(
                        grouped_model,
                        {
                            "window_start_ts": None,
                            "quota_reset_ts": None,
                            "success_count": 0,
                            "failure_count": 0,
                            "request_count": 0,
                            "prompt_tokens": 0,
                            "completion_tokens": 0,
                            "approx_cost": 0.0,
                        },
                    )
                    other_model_data["request_count"] = used_requests
                    if max_requests is not None:
                        other_model_data["quota_max_requests"] = max_requests
                        other_model_data["quota_display"] = (
                            f"{used_requests}/{max_requests}"
                        )
        lib_logger.debug(
            f"Updated quota baseline for {mask_credential(credential)} model={model}: "
            f"remaining={remaining_fraction:.2%}, synced_request_count={used_requests}"
        )
        await self._save_usage()
async def _check_key_lockout(self, key: str, key_data: Dict):
"""
Checks if a key should be locked out due to multiple model failures.
NOTE: This check is currently disabled. The original logic counted individual
models in long-term lockout, but this caused issues with quota groups - when
a single quota group (e.g., "claude" with 5 models) was exhausted, it would
count as 5 lockouts and trigger key-level lockout, blocking other quota groups
(like gemini) that were still available.
The per-model and per-group cooldowns already handle quota exhaustion properly.
"""
# Disabled - see docstring above
pass
async def get_stats_for_endpoint(
    self,
    provider_filter: Optional[str] = None,
    include_global: bool = True,
) -> Dict[str, Any]:
    """
    Get usage stats formatted for the /v1/quota-stats endpoint.

    Aggregates data from key_usage.json grouped by provider.
    Includes both current period stats and global (lifetime) stats.
    All reads of ``self._usage_data`` happen under ``self._data_lock``.

    Args:
        provider_filter: If provided, only return stats for this provider
        include_global: If True, include global/lifetime stats alongside current

    Returns:
        {
            "providers": {
                "provider_name": {
                    "credential_count": int,
                    "active_count": int,
                    "on_cooldown_count": int,
                    "total_requests": int,
                    "tokens": {
                        "input_cached": int,
                        "input_uncached": int,
                        "input_cache_pct": float,
                        "output": int
                    },
                    "approx_cost": float | None,
                    "credentials": [...],
                    "global": {...}  # If include_global is True
                }
            },
            "summary": {...},
            "global_summary": {...},  # If include_global is True
            "timestamp": float
        }
    """
    await self._lazy_init()
    now_ts = time.time()
    providers: Dict[str, Dict[str, Any]] = {}
    # Track global stats separately
    global_providers: Dict[str, Dict[str, Any]] = {}
    async with self._data_lock:
        if not self._usage_data:
            # Empty-data fast path: return a fully-shaped, zeroed payload so
            # callers never have to special-case missing keys.
            # NOTE(review): "global_summary" is emitted here even when
            # include_global is False, unlike the populated path below —
            # confirm whether consumers rely on this.
            return {
                "providers": {},
                "summary": {
                    "total_providers": 0,
                    "total_credentials": 0,
                    "active_credentials": 0,
                    "exhausted_credentials": 0,
                    "total_requests": 0,
                    "tokens": {
                        "input_cached": 0,
                        "input_uncached": 0,
                        "input_cache_pct": 0,
                        "output": 0,
                    },
                    "approx_total_cost": 0.0,
                },
                "global_summary": {
                    "total_providers": 0,
                    "total_credentials": 0,
                    "total_requests": 0,
                    "tokens": {
                        "input_cached": 0,
                        "input_uncached": 0,
                        "input_cache_pct": 0,
                        "output": 0,
                    },
                    "approx_total_cost": 0.0,
                },
                "data_source": "cache",
                "timestamp": now_ts,
            }
        for credential, cred_data in self._usage_data.items():
            # Extract provider from credential path
            provider = self._get_provider_from_credential(credential)
            if not provider:
                continue
            # Apply filter if specified
            if provider_filter and provider != provider_filter:
                continue
            # Initialize provider entry (current-period and global buckets
            # are always created together so later lookups can't KeyError)
            if provider not in providers:
                providers[provider] = {
                    "credential_count": 0,
                    "active_count": 0,
                    "on_cooldown_count": 0,
                    "exhausted_count": 0,
                    "total_requests": 0,
                    "tokens": {
                        "input_cached": 0,
                        "input_uncached": 0,
                        "input_cache_pct": 0,
                        "output": 0,
                    },
                    "approx_cost": 0.0,
                    "credentials": [],
                }
                global_providers[provider] = {
                    "total_requests": 0,
                    "tokens": {
                        "input_cached": 0,
                        "input_uncached": 0,
                        "input_cache_pct": 0,
                        "output": 0,
                    },
                    "approx_cost": 0.0,
                }
            prov_stats = providers[provider]
            prov_stats["credential_count"] += 1
            # Determine credential status and cooldowns.
            # "or 0" normalizes a stored None/falsy value to 0.
            key_cooldown = cred_data.get("key_cooldown_until", 0) or 0
            model_cooldowns = cred_data.get("model_cooldowns", {})
            # Build active cooldowns with remaining time; expired entries
            # (cooldown_ts <= now_ts) are dropped from the response.
            active_cooldowns = {}
            for model, cooldown_ts in model_cooldowns.items():
                if cooldown_ts > now_ts:
                    remaining_seconds = int(cooldown_ts - now_ts)
                    active_cooldowns[model] = {
                        "until_ts": cooldown_ts,
                        "remaining_seconds": remaining_seconds,
                    }
            key_cooldown_remaining = None
            if key_cooldown > now_ts:
                key_cooldown_remaining = int(key_cooldown - now_ts)
            has_active_cooldown = key_cooldown > now_ts or len(active_cooldowns) > 0
            # Check if exhausted (all quota groups exhausted for Antigravity)
            is_exhausted = False
            models_data = cred_data.get("models", {})
            if models_data:
                # Check if any model has remaining quota. A missing baseline
                # (None) is treated as "unknown, assume quota remains".
                all_exhausted = True
                for model_stats in models_data.values():
                    if isinstance(model_stats, dict):
                        baseline = model_stats.get("baseline_remaining_fraction")
                        if baseline is None or baseline > 0:
                            all_exhausted = False
                            break
                if all_exhausted and len(models_data) > 0:
                    is_exhausted = True
            # Status precedence: exhausted > cooldown > active.
            if is_exhausted:
                prov_stats["exhausted_count"] += 1
                status = "exhausted"
            elif has_active_cooldown:
                prov_stats["on_cooldown_count"] += 1
                status = "cooldown"
            else:
                prov_stats["active_count"] += 1
                status = "active"
            # Aggregate token stats (current period)
            cred_tokens = {
                "input_cached": 0,
                "input_uncached": 0,
                "output": 0,
            }
            cred_requests = 0
            cred_cost = 0.0
            # Aggregate global token stats
            cred_global_tokens = {
                "input_cached": 0,
                "input_uncached": 0,
                "output": 0,
            }
            cred_global_requests = 0
            cred_global_cost = 0.0
            # Handle per-model structure (current period)
            if models_data:
                for model_name, model_stats in models_data.items():
                    if not isinstance(model_stats, dict):
                        continue
                    # Prefer request_count if available and non-zero, else fall back to success+failure
                    req_count = model_stats.get("request_count", 0)
                    if req_count > 0:
                        cred_requests += req_count
                    else:
                        cred_requests += model_stats.get("success_count", 0)
                        cred_requests += model_stats.get("failure_count", 0)
                    # Token stats - track cached separately
                    cred_tokens["input_cached"] += model_stats.get(
                        "prompt_tokens_cached", 0
                    )
                    cred_tokens["input_uncached"] += model_stats.get(
                        "prompt_tokens", 0
                    )
                    cred_tokens["output"] += model_stats.get("completion_tokens", 0)
                    cred_cost += model_stats.get("approx_cost", 0.0)
            # Handle legacy daily structure.
            # NOTE(review): unlike the per-model path above, this only counts
            # success_count (no failure_count / request_count) — confirm
            # whether that asymmetry is intended for legacy data.
            daily_data = cred_data.get("daily", {})
            daily_models = daily_data.get("models", {})
            for model_name, model_stats in daily_models.items():
                if not isinstance(model_stats, dict):
                    continue
                cred_requests += model_stats.get("success_count", 0)
                cred_tokens["input_cached"] += model_stats.get(
                    "prompt_tokens_cached", 0
                )
                cred_tokens["input_uncached"] += model_stats.get("prompt_tokens", 0)
                cred_tokens["output"] += model_stats.get("completion_tokens", 0)
                cred_cost += model_stats.get("approx_cost", 0.0)
            # Handle global stats (lifetime totals persisted under "global")
            global_data = cred_data.get("global", {})
            global_models = global_data.get("models", {})
            for model_name, model_stats in global_models.items():
                if not isinstance(model_stats, dict):
                    continue
                cred_global_requests += model_stats.get("success_count", 0)
                cred_global_tokens["input_cached"] += model_stats.get(
                    "prompt_tokens_cached", 0
                )
                cred_global_tokens["input_uncached"] += model_stats.get(
                    "prompt_tokens", 0
                )
                cred_global_tokens["output"] += model_stats.get(
                    "completion_tokens", 0
                )
                cred_global_cost += model_stats.get("approx_cost", 0.0)
            # Add current period stats to global totals (global = persisted
            # lifetime data + the not-yet-rolled-over current period)
            cred_global_requests += cred_requests
            cred_global_tokens["input_cached"] += cred_tokens["input_cached"]
            cred_global_tokens["input_uncached"] += cred_tokens["input_uncached"]
            cred_global_tokens["output"] += cred_tokens["output"]
            cred_global_cost += cred_cost
            # Build credential entry
            # Mask credential identifier for display: env-based credentials
            # are shown verbatim, file-based ones collapse to the file name.
            if credential.startswith("env://"):
                identifier = credential
            else:
                identifier = Path(credential).name
            cred_entry = {
                "identifier": identifier,
                "full_path": credential,
                "status": status,
                "last_used_ts": cred_data.get("last_used_ts"),
                "requests": cred_requests,
                "tokens": cred_tokens,
                "approx_cost": cred_cost if cred_cost > 0 else None,
            }
            # Add cooldown info (keys omitted entirely when inactive)
            if key_cooldown_remaining is not None:
                cred_entry["key_cooldown_remaining"] = key_cooldown_remaining
            if active_cooldowns:
                cred_entry["model_cooldowns"] = active_cooldowns
            # Add global stats for this credential
            if include_global:
                # Calculate global cache percentage (guard div-by-zero)
                global_total_input = (
                    cred_global_tokens["input_cached"]
                    + cred_global_tokens["input_uncached"]
                )
                global_cache_pct = (
                    round(
                        cred_global_tokens["input_cached"]
                        / global_total_input
                        * 100,
                        1,
                    )
                    if global_total_input > 0
                    else 0
                )
                cred_entry["global"] = {
                    "requests": cred_global_requests,
                    "tokens": {
                        "input_cached": cred_global_tokens["input_cached"],
                        "input_uncached": cred_global_tokens["input_uncached"],
                        "input_cache_pct": global_cache_pct,
                        "output": cred_global_tokens["output"],
                    },
                    "approx_cost": cred_global_cost
                    if cred_global_cost > 0
                    else None,
                }
            # Add model-specific data for providers with per-model tracking
            if models_data:
                cred_entry["models"] = {}
                for model_name, model_stats in models_data.items():
                    if not isinstance(model_stats, dict):
                        continue
                    cred_entry["models"][model_name] = {
                        "requests": model_stats.get("success_count", 0)
                        + model_stats.get("failure_count", 0),
                        "request_count": model_stats.get("request_count", 0),
                        "success_count": model_stats.get("success_count", 0),
                        "failure_count": model_stats.get("failure_count", 0),
                        "prompt_tokens": model_stats.get("prompt_tokens", 0),
                        "prompt_tokens_cached": model_stats.get(
                            "prompt_tokens_cached", 0
                        ),
                        "completion_tokens": model_stats.get(
                            "completion_tokens", 0
                        ),
                        "approx_cost": model_stats.get("approx_cost", 0.0),
                        "window_start_ts": model_stats.get("window_start_ts"),
                        "quota_reset_ts": model_stats.get("quota_reset_ts"),
                        # Quota baseline fields (Antigravity-specific)
                        "baseline_remaining_fraction": model_stats.get(
                            "baseline_remaining_fraction"
                        ),
                        "baseline_fetched_at": model_stats.get(
                            "baseline_fetched_at"
                        ),
                        "quota_max_requests": model_stats.get("quota_max_requests"),
                        "quota_display": model_stats.get("quota_display"),
                    }
            prov_stats["credentials"].append(cred_entry)
            # Aggregate to provider totals (current period)
            prov_stats["total_requests"] += cred_requests
            prov_stats["tokens"]["input_cached"] += cred_tokens["input_cached"]
            prov_stats["tokens"]["input_uncached"] += cred_tokens["input_uncached"]
            prov_stats["tokens"]["output"] += cred_tokens["output"]
            if cred_cost > 0:
                prov_stats["approx_cost"] += cred_cost
            # Aggregate to global provider totals
            global_providers[provider]["total_requests"] += cred_global_requests
            global_providers[provider]["tokens"]["input_cached"] += (
                cred_global_tokens["input_cached"]
            )
            global_providers[provider]["tokens"]["input_uncached"] += (
                cred_global_tokens["input_uncached"]
            )
            global_providers[provider]["tokens"]["output"] += cred_global_tokens[
                "output"
            ]
            global_providers[provider]["approx_cost"] += cred_global_cost
        # Post-processing: derive cache percentages and normalize zero costs
        # to None for each provider bucket.
        for provider, prov_stats in providers.items():
            total_input = (
                prov_stats["tokens"]["input_cached"]
                + prov_stats["tokens"]["input_uncached"]
            )
            if total_input > 0:
                prov_stats["tokens"]["input_cache_pct"] = round(
                    prov_stats["tokens"]["input_cached"] / total_input * 100, 1
                )
            # Set cost to None if 0
            if prov_stats["approx_cost"] == 0:
                prov_stats["approx_cost"] = None
            # Calculate global cache percentages
            if include_global and provider in global_providers:
                gp = global_providers[provider]
                global_total = (
                    gp["tokens"]["input_cached"] + gp["tokens"]["input_uncached"]
                )
                if global_total > 0:
                    gp["tokens"]["input_cache_pct"] = round(
                        gp["tokens"]["input_cached"] / global_total * 100, 1
                    )
                if gp["approx_cost"] == 0:
                    gp["approx_cost"] = None
                prov_stats["global"] = gp
        # Build summary (current period)
        total_creds = sum(p["credential_count"] for p in providers.values())
        active_creds = sum(p["active_count"] for p in providers.values())
        exhausted_creds = sum(p["exhausted_count"] for p in providers.values())
        total_requests = sum(p["total_requests"] for p in providers.values())
        total_input_cached = sum(
            p["tokens"]["input_cached"] for p in providers.values()
        )
        total_input_uncached = sum(
            p["tokens"]["input_uncached"] for p in providers.values()
        )
        total_output = sum(p["tokens"]["output"] for p in providers.values())
        # "or 0" because approx_cost may have been normalized to None above
        total_cost = sum(p["approx_cost"] or 0 for p in providers.values())
        total_input = total_input_cached + total_input_uncached
        input_cache_pct = (
            round(total_input_cached / total_input * 100, 1) if total_input > 0 else 0
        )
        result = {
            "providers": providers,
            "summary": {
                "total_providers": len(providers),
                "total_credentials": total_creds,
                "active_credentials": active_creds,
                "exhausted_credentials": exhausted_creds,
                "total_requests": total_requests,
                "tokens": {
                    "input_cached": total_input_cached,
                    "input_uncached": total_input_uncached,
                    "input_cache_pct": input_cache_pct,
                    "output": total_output,
                },
                "approx_total_cost": total_cost if total_cost > 0 else None,
            },
            "data_source": "cache",
            "timestamp": now_ts,
        }
        # Build global summary
        if include_global:
            global_total_requests = sum(
                gp["total_requests"] for gp in global_providers.values()
            )
            global_total_input_cached = sum(
                gp["tokens"]["input_cached"] for gp in global_providers.values()
            )
            global_total_input_uncached = sum(
                gp["tokens"]["input_uncached"] for gp in global_providers.values()
            )
            global_total_output = sum(
                gp["tokens"]["output"] for gp in global_providers.values()
            )
            global_total_cost = sum(
                gp["approx_cost"] or 0 for gp in global_providers.values()
            )
            global_total_input = global_total_input_cached + global_total_input_uncached
            global_input_cache_pct = (
                round(global_total_input_cached / global_total_input * 100, 1)
                if global_total_input > 0
                else 0
            )
            result["global_summary"] = {
                "total_providers": len(global_providers),
                "total_credentials": total_creds,
                "total_requests": global_total_requests,
                "tokens": {
                    "input_cached": global_total_input_cached,
                    "input_uncached": global_total_input_uncached,
                    "input_cache_pct": global_input_cache_pct,
                    "output": global_total_output,
                },
                "approx_total_cost": global_total_cost
                if global_total_cost > 0
                else None,
            }
        return result
async def reload_from_disk(self) -> None:
    """
    Force reload usage data from disk.

    Useful when another process may have updated the file.

    Runs entirely under ``self._init_lock`` so it cannot race with
    ``_lazy_init`` or a concurrent reload.
    """
    async with self._init_lock:
        # Clear the initialized flag first (presumably an asyncio.Event —
        # TODO confirm) so readers waiting on it block until the fresh
        # data is fully loaded.
        self._initialized.clear()
        await self._load_usage()
        # Re-apply the daily rollover check against the freshly loaded data.
        await self._reset_daily_stats_if_needed()
        # Signal that the reloaded state is ready for use.
        self._initialized.set()