import json
import os
import time
import logging
import asyncio
from datetime import date, datetime, timezone, time as dt_time
from typing import Any, Dict, List, Optional, Set

import aiofiles
import litellm

from .error_handler import ClassifiedError

# Library logger: records do not propagate to ancestor loggers, and a
# NullHandler is attached only when the application configured no handler.
lib_logger = logging.getLogger('rotator_library')
lib_logger.propagate = False
if not lib_logger.handlers:
    lib_logger.addHandler(logging.NullHandler())


class UsageManager:
    """
    Manages usage statistics and cooldowns for API keys with asyncio-safe
    locking, asynchronous file I/O, and a lazy-loading mechanism for usage data.

    Usage data is a JSON object persisted at ``file_path``. Each API key maps
    to a dict that this class reads and writes under the entries ``daily``,
    ``global``, ``model_cooldowns``, ``key_cooldown_until``, ``failures``,
    ``last_daily_reset``, ``last_used_ts`` and ``last_failure``.

    Locking rule: ``_save_usage`` acquires ``_data_lock`` itself, and
    ``asyncio.Lock`` is not reentrant, so ``_save_usage`` must never be
    awaited while ``_data_lock`` is held.
    """

    # Escalating per-model backoff: consecutive failure count -> cooldown seconds.
    _BACKOFF_TIERS = {1: 10, 2: 30, 3: 60, 4: 120}
    # Cooldown applied once the failure count is past the last tier (2 hours).
    _LONG_TERM_COOLDOWN_SECONDS = 7200
    # Duration of a key-level lockout (5 minutes).
    _KEY_LOCKOUT_SECONDS = 300
    # Number of long-term locked-out models that triggers a key-level lockout.
    _KEY_LOCKOUT_MODEL_THRESHOLD = 3

    def __init__(self, file_path: str = "key_usage.json", wait_timeout: int = 13, daily_reset_time_utc: Optional[str] = "03:00"):
        """
        Args:
            file_path: Path of the JSON file used to persist usage data.
            wait_timeout: Seconds ``acquire_key`` waits on a key's condition
                before re-evaluating all keys.
            daily_reset_time_utc: ``"HH:MM"`` (UTC) time of the daily reset;
                a falsy value disables the startup daily reset.

        Raises:
            ValueError: If ``daily_reset_time_utc`` is not ``"HH:MM"``.
        """
        self.file_path = file_path
        # Per-key in-memory state: {"lock", "condition", "models_in_use"}.
        self.key_states: Dict[str, Dict[str, Any]] = {}
        self.wait_timeout = wait_timeout

        self._data_lock = asyncio.Lock()
        # None until the first lazy load has run.
        self._usage_data: Optional[Dict] = None
        self._initialized = asyncio.Event()
        self._init_lock = asyncio.Lock()

        # Not referenced by any method visible in this file; kept so that
        # existing attribute access elsewhere keeps working.
        self._timeout_lock = asyncio.Lock()
        self._claimed_on_timeout: Set[str] = set()

        if daily_reset_time_utc:
            hour, minute = map(int, daily_reset_time_utc.split(':'))
            self.daily_reset_time_utc = dt_time(hour=hour, minute=minute, tzinfo=timezone.utc)
        else:
            self.daily_reset_time_utc = None

    async def _lazy_init(self):
        """Initializes the usage data by loading it from the file asynchronously.

        Safe to call repeatedly: ``_init_lock`` serialises callers and the
        ``_initialized`` event ensures load + daily reset run only once.
        """
        async with self._init_lock:
            if not self._initialized.is_set():
                await self._load_usage()
                await self._reset_daily_stats_if_needed()
                self._initialized.set()

    async def _load_usage(self):
        """Loads usage data from the JSON file asynchronously.

        A missing, unreadable or malformed file results in empty usage data
        rather than an exception.
        """
        async with self._data_lock:
            if not os.path.exists(self.file_path):
                self._usage_data = {}
                return
            try:
                async with aiofiles.open(self.file_path, 'r') as f:
                    content = await f.read()
                loaded = json.loads(content)
            except (json.JSONDecodeError, IOError, FileNotFoundError):
                loaded = {}
            # Every other method treats the data as a dict keyed by API key;
            # valid JSON of another shape (list, string, ...) is discarded.
            self._usage_data = loaded if isinstance(loaded, dict) else {}

    async def _save_usage(self):
        """Saves the current usage data to the JSON file asynchronously.

        Acquires ``_data_lock``; callers must NOT hold that lock.
        """
        if self._usage_data is None:
            return
        async with self._data_lock:
            async with aiofiles.open(self.file_path, 'w') as f:
                await f.write(json.dumps(self._usage_data, indent=2))

    @staticmethod
    def _archive_daily_stats(data: Dict[str, Any]) -> None:
        """Adds the per-model counters of ``data["daily"]`` to ``data["global"]``.

        Does not clear ``data["daily"]``; the caller replaces the bucket
        afterwards so that each bucket is archived exactly once.
        """
        daily_data = data.get("daily", {})
        if not daily_data:
            return
        global_data = data.setdefault("global", {"models": {}})
        global_models = global_data.setdefault("models", {})
        for model, stats in daily_data.get("models", {}).items():
            totals = global_models.setdefault(
                model,
                {"success_count": 0, "prompt_tokens": 0, "completion_tokens": 0, "approx_cost": 0.0},
            )
            totals["success_count"] += stats.get("success_count", 0)
            totals["prompt_tokens"] += stats.get("prompt_tokens", 0)
            totals["completion_tokens"] += stats.get("completion_tokens", 0)
            totals["approx_cost"] += stats.get("approx_cost", 0.0)

    async def _reset_daily_stats_if_needed(self):
        """Checks if daily stats need to be reset for any key.

        A key is reset when it has no usable ``last_daily_reset`` value, or
        when its last reset is before today's reset time and that time has
        already passed. A reset clears cooldowns and failure counters,
        archives the daily counters into ``global`` and starts a fresh
        ``daily`` bucket. The file is saved only if something changed.
        """
        if self._usage_data is None or not self.daily_reset_time_utc:
            return

        now_utc = datetime.now(timezone.utc)
        today_str = now_utc.date().isoformat()
        # Identical for every key, so computed once. The result is
        # timezone-aware because daily_reset_time_utc carries tzinfo=UTC.
        reset_threshold_today = datetime.combine(now_utc.date(), self.daily_reset_time_utc)
        needs_saving = False

        for key, data in self._usage_data.items():
            last_reset_str = data.get("last_daily_reset", "")
            if last_reset_str == today_str:
                continue

            last_reset_dt = None
            if last_reset_str:
                try:
                    # Ensure the parsed datetime is timezone-aware (UTC)
                    last_reset_dt = datetime.fromisoformat(last_reset_str).replace(tzinfo=timezone.utc)
                except (TypeError, ValueError):
                    # A corrupt marker is treated like a missing one instead
                    # of aborting initialisation for every key.
                    last_reset_dt = None

            if last_reset_dt is not None and not (last_reset_dt < reset_threshold_today <= now_utc):
                continue

            lib_logger.info(f"Performing daily reset for key ...{key[-4:]}")
            needs_saving = True

            # Reset cooldowns
            data["model_cooldowns"] = {}
            data["key_cooldown_until"] = None

            # Reset consecutive failures
            if "failures" in data:
                data["failures"] = {}

            # Archive global stats from the previous day's 'daily'
            self._archive_daily_stats(data)

            # Reset daily stats
            data["daily"] = {"date": today_str, "models": {}}
            data["last_daily_reset"] = today_str

        if needs_saving:
            await self._save_usage()

    def _initialize_key_states(self, keys: List[str]):
        """Initializes state tracking for all provided keys if not already present."""
        for key in keys:
            if key not in self.key_states:
                self.key_states[key] = {
                    "lock": asyncio.Lock(),
                    "condition": asyncio.Condition(),
                    "models_in_use": set()
                }

    async def acquire_key(self, available_keys: List[str], model: str) -> str:
        """
        Acquires the best available key using a tiered, model-aware locking strategy.

        Keys on a key-level or model-level cooldown are skipped. Tier 1 holds
        keys not in use by any model; Tier 2 holds keys in use, but not by
        ``model``. Within a tier the key with the fewest daily successes for
        ``model`` is tried first. If nothing can be claimed, the coroutine
        waits (up to ``wait_timeout`` seconds) for a release and retries.
        This method loops until a key is acquired.

        Args:
            available_keys: Candidate API keys.
            model: Model name the key will be used for.

        Returns:
            The acquired key; the caller must later call ``release_key``.
        """
        await self._lazy_init()
        self._initialize_key_states(available_keys)

        while True:
            tier1_keys, tier2_keys = [], []

            async with self._data_lock:
                now = time.time()
                for key in available_keys:
                    key_data = self._usage_data.get(key, {})

                    # Skip keys on global or model-specific cooldown
                    if (key_data.get("key_cooldown_until") or 0) > now or \
                       (key_data.get("model_cooldowns", {}).get(model) or 0) > now:
                        continue

                    usage_count = key_data.get("daily", {}).get("models", {}).get(model, {}).get("success_count", 0)
                    key_state = self.key_states[key]

                    if not key_state["models_in_use"]:
                        tier1_keys.append((key, usage_count))
                    elif model not in key_state["models_in_use"]:
                        tier2_keys.append((key, usage_count))

            # Sort keys by usage count (ascending)
            tier1_keys.sort(key=lambda x: x[1])
            tier2_keys.sort(key=lambda x: x[1])

            # Attempt to acquire from Tier 1 (completely free).
            # The condition is re-checked under the per-key lock because the
            # tier lists above were built without holding it.
            for key, _ in tier1_keys:
                state = self.key_states[key]
                async with state["lock"]:
                    if not state["models_in_use"]:
                        state["models_in_use"].add(model)
                        lib_logger.info(f"Acquired Tier 1 key ...{key[-4:]} for model {model}")
                        return key

            # Attempt to acquire from Tier 2 (in use by other models)
            for key, _ in tier2_keys:
                state = self.key_states[key]
                async with state["lock"]:
                    if model not in state["models_in_use"]:
                        state["models_in_use"].add(model)
                        lib_logger.info(f"Acquired Tier 2 key ...{key[-4:]} for model {model}")
                        return key

            # If no key is available, wait for one to be released
            lib_logger.info("All eligible keys are currently locked for this model. Waiting...")

            # Create a combined list of all potentially usable keys to wait on
            all_potential_keys = tier1_keys + tier2_keys
            if not all_potential_keys:
                lib_logger.warning("No keys are eligible at all (all on cooldown). Waiting before re-evaluating.")
                await asyncio.sleep(5)
                continue

            # Wait on the condition of the best available key
            best_wait_key = min(all_potential_keys, key=lambda x: x[1])[0]
            wait_condition = self.key_states[best_wait_key]["condition"]

            try:
                async with wait_condition:
                    await asyncio.wait_for(wait_condition.wait(), timeout=self.wait_timeout)
                lib_logger.info("Notified that a key was released. Re-evaluating...")
            except asyncio.TimeoutError:
                lib_logger.warning("Wait timed out. Re-evaluating for any available key.")

    async def release_key(self, key: str, model: str):
        """Releases a key's lock for a specific model and notifies waiting tasks.

        Unknown keys are ignored. Releasing a model that is not registered on
        the key only logs a warning; waiters are notified either way.
        """
        if key not in self.key_states:
            return

        state = self.key_states[key]
        async with state["lock"]:
            if model in state["models_in_use"]:
                state["models_in_use"].remove(model)
                lib_logger.info(f"Released key ...{key[-4:]} from model {model}")
            else:
                lib_logger.warning(f"Attempted to release key ...{key[-4:]} for model {model}, but it was not in use.")

        # Notify all tasks waiting on this key's condition
        async with state["condition"]:
            state["condition"].notify_all()

    async def record_success(self, key: str, model: str, completion_response: Optional[litellm.ModelResponse] = None):
        """
        Records a successful API call, resetting failure counters.
        It safely handles cases where token usage data is not available.

        Resets the model's consecutive-failure counter, removes its cooldown,
        increments today's success count and, when the response carries usage
        data, adds token counts and the cost reported by
        ``litellm.completion_cost``. The data is then saved to disk.

        Args:
            key: API key that served the request.
            model: Model that was called.
            completion_response: Response object; may be ``None``.
        """
        await self._lazy_init()
        async with self._data_lock:
            today_utc_str = datetime.now(timezone.utc).date().isoformat()
            key_data = self._usage_data.setdefault(key, {"daily": {"date": today_utc_str, "models": {}}, "global": {"models": {}}, "model_cooldowns": {}, "failures": {}})

            # Perform a just-in-time daily reset if the date has changed.
            # The previous bucket is archived first; otherwise a process that
            # runs past a date change would drop that day's totals.
            if (key_data.get("daily") or {}).get("date") != today_utc_str:
                self._archive_daily_stats(key_data)
                key_data["daily"] = {"date": today_utc_str, "models": {}}

            # Always record a success and reset failures
            model_failures = key_data.setdefault("failures", {}).setdefault(model, {})
            model_failures["consecutive_failures"] = 0
            if model in key_data.get("model_cooldowns", {}):
                del key_data["model_cooldowns"][model]

            daily_models = key_data["daily"].setdefault("models", {})
            daily_model_data = daily_models.setdefault(model, {"success_count": 0, "prompt_tokens": 0, "completion_tokens": 0, "approx_cost": 0.0})
            daily_model_data["success_count"] += 1

            # Safely attempt to record token and cost usage
            if completion_response and hasattr(completion_response, 'usage') and completion_response.usage:
                usage = completion_response.usage
                # A missing or None counter is treated as 0 so that the
                # success itself is never lost to a TypeError.
                daily_model_data["prompt_tokens"] += getattr(usage, "prompt_tokens", 0) or 0
                daily_model_data["completion_tokens"] += getattr(usage, "completion_tokens", 0) or 0

                try:
                    cost = litellm.completion_cost(completion_response=completion_response)
                    daily_model_data["approx_cost"] += cost
                except Exception as e:
                    lib_logger.warning(f"Could not calculate cost for model {model}: {e}")
            else:
                lib_logger.warning(f"No usage data found in completion response for model {model}. Recording success without token count.")

            key_data["last_used_ts"] = time.time()

        # Saved only after _data_lock is released (_save_usage re-acquires it).
        await self._save_usage()

    async def record_failure(self, key: str, model: str, classified_error: ClassifiedError):
        """Records a failure and applies cooldowns based on an escalating backoff strategy.

        * ``rate_limit`` with a ``retry_after`` value: that value becomes the
          model cooldown.
        * ``authentication``: a 5-minute key-level lockout is applied and
          nothing else is recorded.
        * anything else: the model's consecutive-failure counter is
          incremented and the cooldown is taken from the backoff tiers
          (2 hours once past the last tier).

        For the first and last case the key-level lockout rule is evaluated
        and ``last_failure`` is updated. The data is then saved to disk.
        """
        await self._lazy_init()
        async with self._data_lock:
            today_utc_str = datetime.now(timezone.utc).date().isoformat()
            key_data = self._usage_data.setdefault(key, {"daily": {"date": today_utc_str, "models": {}}, "global": {"models": {}}, "model_cooldowns": {}, "failures": {}})

            # None means "no model cooldown to apply" (authentication path).
            cooldown_seconds = None

            # Handle specific error types first
            if classified_error.error_type == 'rate_limit' and classified_error.retry_after:
                cooldown_seconds = classified_error.retry_after
            elif classified_error.error_type == 'authentication':
                # Apply a 5-minute key-level lockout for auth errors
                key_data["key_cooldown_until"] = time.time() + self._KEY_LOCKOUT_SECONDS
                lib_logger.warning(f"Authentication error on key ...{key[-4:]}. Applying 5-minute key-level lockout.")
                # No further backoff logic needed
            else:
                # General backoff logic for other errors
                failures_data = key_data.setdefault("failures", {})
                model_failures = failures_data.setdefault(model, {"consecutive_failures": 0})
                model_failures["consecutive_failures"] += 1
                count = model_failures["consecutive_failures"]

                # Default to 2 hours
                cooldown_seconds = self._BACKOFF_TIERS.get(count, self._LONG_TERM_COOLDOWN_SECONDS)

            if cooldown_seconds is not None:
                # Apply the cooldown
                model_cooldowns = key_data.setdefault("model_cooldowns", {})
                model_cooldowns[model] = time.time() + cooldown_seconds
                lib_logger.warning(f"Failure recorded for key ...{key[-4:]} with model {model}. Applying {cooldown_seconds}s cooldown.")

                # Check for key-level lockout condition
                await self._check_key_lockout(key, key_data)

                key_data["last_failure"] = {
                    "timestamp": time.time(),
                    "model": model,
                    "error": str(classified_error.original_exception)
                }

        # Every path saves here, after _data_lock is released: _save_usage
        # takes the same non-reentrant lock and would otherwise block forever.
        await self._save_usage()

    async def _check_key_lockout(self, key: str, key_data: Dict):
        """Checks if a key should be locked out due to multiple model failures.

        A model counts as being in long-term lockout when its cooldown is
        still active and either
          * at least 2 hours of cooldown remain, or
          * its consecutive-failure count is past the last backoff tier,
            i.e. its latest generic failure was given the 2-hour cooldown.
        The second test is needed because the remaining time of a 2-hour
        cooldown is already below 7200 seconds by the time it is inspected.

        When three or more models qualify, a 5-minute key-level cooldown is
        set on ``key_data``. This method does not save to disk.
        """
        now = time.time()
        failures = key_data.get("failures", {})
        last_tier = max(self._BACKOFF_TIERS)
        long_term_lockout_models = 0

        for model, cooldown_end in key_data.get("model_cooldowns", {}).items():
            remaining = (cooldown_end or 0) - now
            if remaining <= 0:
                continue
            past_last_tier = failures.get(model, {}).get("consecutive_failures", 0) > last_tier
            if remaining >= self._LONG_TERM_COOLDOWN_SECONDS or past_last_tier:
                long_term_lockout_models += 1

        if long_term_lockout_models >= self._KEY_LOCKOUT_MODEL_THRESHOLD:
            key_data["key_cooldown_until"] = now + self._KEY_LOCKOUT_SECONDS  # 5-minute key lockout
            lib_logger.error(f"Key ...{key[-4:]} has {long_term_lockout_models} models in long-term lockout. Applying 5-minute key-level lockout.")