""" Groq API key pool manager with dynamic routing and persisted usage stats. This module supports: - GROQ_API_KEYS (comma/semicolon/newline-separated values) - GROQ_API_KEY (single key fallback) - Dynamic key selection based on prior usage and failures - Failure cooldowns with exponential backoff - Persistent state for cross-run adaptation """ from __future__ import annotations import hashlib import json import time from pathlib import Path from typing import Any, Dict, List, Optional, Tuple from groq import Groq def _split_keys(raw: str) -> List[str]: normalized = raw.replace(";", ",").replace("\n", ",") keys = [token.strip() for token in normalized.split(",") if token.strip()] # Deduplicate while preserving order seen = set() unique_keys: List[str] = [] for key in keys: if key not in seen: seen.add(key) unique_keys.append(key) return unique_keys def parse_groq_keys(api_key: str = "", api_keys_csv: str = "") -> List[str]: """Parse key inputs from env values or direct args.""" values: List[str] = [] if api_keys_csv.strip(): values.extend(_split_keys(api_keys_csv)) if api_key.strip() and api_key.strip() not in values: values.append(api_key.strip()) return values class GroqKeyPool: """Usage-aware key pool with cooldowns and persistent stats.""" def __init__( self, api_keys: List[str], base_url: str, state_file: Path, cooldown_base_seconds: int = 30, ) -> None: self._base_url = base_url self._state_file = state_file self._cooldown_base_seconds = cooldown_base_seconds self._keys_by_id: Dict[str, str] = {} self._clients_by_id: Dict[str, Groq] = {} self._stats: Dict[str, Dict[str, Any]] = {} for key in api_keys: key_id = self._fingerprint(key) self._keys_by_id[key_id] = key self._clients_by_id[key_id] = Groq(api_key=key, base_url=base_url) self._load_state() self._ensure_stat_entries() @staticmethod def _fingerprint(key: str) -> str: digest = hashlib.sha1(key.encode("utf-8")).hexdigest()[:10] return f"k_{digest}" def _default_stats(self) -> Dict[str, Any]: return { "requests": 0, "successes": 0, "failures": 0, "consecutive_failures": 0, "last_used_at": 0.0, "last_success_at": 0.0, "cooldown_until": 0.0, "last_error": "", } def _load_state(self) -> None: if not self._state_file.exists(): self._stats = {} return try: with open(self._state_file, "r", encoding="utf-8") as file: data = json.load(file) if isinstance(data, dict): self._stats = data else: self._stats = {} except Exception: self._stats = {} def _save_state(self) -> None: self._state_file.parent.mkdir(parents=True, exist_ok=True) with open(self._state_file, "w", encoding="utf-8") as file: json.dump(self._stats, file, indent=2) def _ensure_stat_entries(self) -> None: key_ids = set(self._keys_by_id.keys()) # Remove stale keys from state self._stats = {key_id: stats for key_id, stats in self._stats.items() if key_id in key_ids} # Add missing keys for key_id in key_ids: if key_id not in self._stats: self._stats[key_id] = self._default_stats() self._save_state() def _now(self) -> float: return time.time() def _is_cooling(self, key_id: str, now: float) -> bool: return float(self._stats[key_id].get("cooldown_until", 0.0)) > now def _selection_tuple(self, key_id: str, now: float) -> Tuple[int, float, int, int, float]: stats = self._stats[key_id] cooling_rank = 1 if self._is_cooling(key_id, now) else 0 cooldown_until = float(stats.get("cooldown_until", 0.0)) consecutive_failures = int(stats.get("consecutive_failures", 0)) requests = int(stats.get("requests", 0)) last_used_at = float(stats.get("last_used_at", 0.0)) return (cooling_rank, cooldown_until, consecutive_failures, requests, last_used_at) def acquire_key(self) -> Optional[str]: if not self._keys_by_id: return None now = self._now() ordered = sorted(self._keys_by_id.keys(), key=lambda key_id: self._selection_tuple(key_id, now)) return ordered[0] if ordered else None def get_client(self, key_id: str) -> Groq: return self._clients_by_id[key_id] def mark_request(self, key_id: str) -> None: stats = self._stats[key_id] stats["requests"] = int(stats.get("requests", 0)) + 1 stats["last_used_at"] = self._now() self._save_state() def mark_success(self, key_id: str) -> None: stats = self._stats[key_id] stats["successes"] = int(stats.get("successes", 0)) + 1 stats["consecutive_failures"] = 0 stats["last_success_at"] = self._now() stats["last_error"] = "" stats["cooldown_until"] = 0.0 self._save_state() def mark_failure(self, key_id: str, error_text: str) -> None: stats = self._stats[key_id] stats["failures"] = int(stats.get("failures", 0)) + 1 stats["consecutive_failures"] = int(stats.get("consecutive_failures", 0)) + 1 stats["last_error"] = error_text[:300] error_lower = error_text.lower() should_cooldown = any(token in error_lower for token in [ "rate limit", "429", "quota", "temporarily unavailable", "timeout", "connection", "auth", "unauthorized", "forbidden", ]) if should_cooldown: multiplier = min(int(stats["consecutive_failures"]), 5) cooldown_seconds = self._cooldown_base_seconds * (2 ** (multiplier - 1)) stats["cooldown_until"] = self._now() + cooldown_seconds self._save_state() def snapshot(self) -> Dict[str, Any]: now = self._now() redacted = {} for key_id, stats in self._stats.items(): redacted[key_id] = { "requests": int(stats.get("requests", 0)), "successes": int(stats.get("successes", 0)), "failures": int(stats.get("failures", 0)), "consecutive_failures": int(stats.get("consecutive_failures", 0)), "cooling": float(stats.get("cooldown_until", 0.0)) > now, "last_error": stats.get("last_error", ""), } return { "total_keys": len(self._keys_by_id), "keys": redacted, }