"""
api_key_loader.py — API key management for Groq and HuggingFace Inference API.

IMPROVEMENTS (v2):
- ADDED: Persistent state tracking (last_used_key_index, call_counts)
- ADDED: Intelligent key rotation with failure awareness
- ADDED: Health tracking for each API key
- ADDED: State persistence via JSON state file
- KEPT: Groq key rotation (enhanced round-robin)
- REMOVED: Together AI (unreliable, rate-limited)
- REMOVED: Gemini stub (kept for backward compat only)

Supported env vars:
    GROQ_API_KEY   — single Groq key
    GROQ_API_KEYS  — comma-separated Groq keys (rotated with tracking)
    HF_TOKEN       — HuggingFace token (optional, increases rate limits)

State file:
    .api_state.json — tracks key usage, failure counts, rotations
"""

import os
import json
import logging
from typing import List, Optional, Dict, Any
from pathlib import Path

try:
    from dotenv import load_dotenv
except ImportError:
    # python-dotenv only adds .env-file support; keys exported directly in
    # the process environment keep working without it.
    def load_dotenv(*args: Any, **kwargs: Any) -> bool:
        """No-op stand-in used when python-dotenv is not installed."""
        return False

load_dotenv()

logger = logging.getLogger(__name__)


class APIKeyManager:
    """
    Manages API keys for Groq and HuggingFace with intelligent load balancing.

    Features:
    - Round-robin rotation with persistent state tracking
    - Failure counting per key (skips consistently failing keys)
    - Call count tracking per key
    - Dynamic allocation based on key health

    Per-key counters are stored under the key's position (as a string) in
    ``groq_keys``, so they follow the order of the keys in the environment.
    """

    STATE_FILE = ".api_state.json"
    MAX_KEY_FAILURES = 5  # Skip key after N consecutive failures

    def __init__(self):
        self.groq_keys: List[str] = []
        self.groq_index: int = 0
        self.hf_token: Optional[str] = None

        # State tracking for load balancing
        self.state: Dict[str, Any] = {
            "last_used_groq_index": 0,
            "groq_call_counts": {},     # key_index -> call_count
            "groq_failure_counts": {},  # key_index -> consecutive_failure_count
            "groq_total_calls": 0,
            "groq_total_failures": 0,
        }

        # Legacy stubs — kept for backward compat but unused
        self.together_keys: List[str] = []
        self.together_index: int = 0
        self.gemini_keys: List[str] = []
        self.gemini_index: int = 0

        self._load_keys()
        self._load_state()

    @staticmethod
    def _read_keys_from_env(plural_var: str, single_var: str) -> List[str]:
        """
        Read a key list from the environment.

        Args:
            plural_var: Name of the comma-separated variable (takes priority)
            single_var: Name of the single-key fallback variable

        Returns:
            Stripped, non-empty keys; an empty list when neither variable is set
        """
        raw = os.getenv(plural_var, "").strip()
        keys = [part.strip() for part in raw.split(",") if part.strip()]
        if not keys:
            single = os.getenv(single_var, "").strip()
            if single:
                keys = [single]
        return keys

    @staticmethod
    def _is_count(value: Any) -> bool:
        """Return True for real integers (bool is excluded although it subclasses int)."""
        return isinstance(value, int) and not isinstance(value, bool)

    def _load_keys(self):
        """Load API keys from environment."""
        # --- Groq ---
        self.groq_keys = self._read_keys_from_env("GROQ_API_KEYS", "GROQ_API_KEY")

        # Initialize state for each key
        for i in range(len(self.groq_keys)):
            self.state["groq_call_counts"].setdefault(str(i), 0)
            self.state["groq_failure_counts"].setdefault(str(i), 0)

        # --- HuggingFace ---
        self.hf_token = os.getenv("HF_TOKEN", "").strip() or None

        # --- Legacy Together AI & Gemini (stubs only) ---
        self.together_keys = self._read_keys_from_env("TOGETHER_API_KEYS", "TOGETHER_API_KEY")
        self.gemini_keys = self._read_keys_from_env("GEMINI_API_KEYS", "GEMINI_API_KEY")

    def _load_state(self):
        """
        Load persisted state from disk.

        Loading is best-effort: a missing, unreadable or malformed state file
        leaves the in-memory defaults untouched. Entries whose type does not
        match the default (e.g. a list where a counter dict is expected) are
        ignored so that later arithmetic on the counters cannot fail.
        """
        state_path = Path(self.STATE_FILE)
        if not state_path.exists():
            return

        try:
            with open(state_path, "r", encoding="utf-8") as f:
                saved_state = json.load(f)
        except Exception as e:  # deliberate best-effort boundary
            logger.warning(f"Failed to load API state: {e}. Starting fresh.")
            return

        if not isinstance(saved_state, dict):
            logger.warning("Failed to load API state: expected a JSON object. Starting fresh.")
            return

        for name, value in saved_state.items():
            current = self.state.get(name)
            if isinstance(current, dict):
                if not isinstance(value, dict):
                    logger.warning(f"Ignoring malformed API state entry '{name}'")
                    continue
                # Merge instead of replace so keys added since the last run
                # keep the zero counters created by _load_keys().
                current.update(
                    {str(k): v for k, v in value.items() if self._is_count(v)}
                )
            elif self._is_count(current):
                if self._is_count(value):
                    self.state[name] = value
                else:
                    logger.warning(f"Ignoring malformed API state entry '{name}'")
            else:
                # Unknown entries are preserved as-is.
                self.state[name] = value

        saved_index = self.state.get("last_used_groq_index", 0)
        if self.groq_keys:
            # The key list may have shrunk since the state was written.
            saved_index %= len(self.groq_keys)
        self.groq_index = saved_index

        logger.info(f"✓ Loaded API state: Groq index={self.groq_index}, "
                    f"calls={self.state['groq_total_calls']}, "
                    f"failures={self.state['groq_total_failures']}")

    def _save_state(self):
        """Persist state to disk (best-effort; failures are logged, not raised)."""
        self.state["last_used_groq_index"] = self.groq_index
        state_path = Path(self.STATE_FILE)
        try:
            # Serialize first so a serialization error cannot leave a
            # truncated, half-written state file behind.
            payload = json.dumps(self.state, indent=2)
            state_path.write_text(payload, encoding="utf-8")
        except (OSError, TypeError, ValueError) as e:
            logger.warning(f"Failed to save API state: {e}")

    def _is_key_healthy(self, key_index: int) -> bool:
        """Check if a key is healthy (not exceeding failure threshold)."""
        failures = self.state["groq_failure_counts"].get(str(key_index), 0)
        is_healthy = failures < self.MAX_KEY_FAILURES
        if not is_healthy:
            logger.warning(f"⚠ Groq key {key_index} has {failures} failures — skipping")
        return is_healthy

    def _get_next_healthy_key_index(self, start_index: int) -> int:
        """
        Find the next healthy key starting from start_index.

        Args:
            start_index: Position to start probing from; values outside the
                key range wrap around.

        Returns:
            Index of the first healthy key at or after start_index. When every
            key is unhealthy, all failure counts are reset and start_index
            (wrapped into range) is returned. With no keys configured,
            start_index is returned unchanged.
        """
        total = len(self.groq_keys)
        if not total:
            return start_index

        for attempt in range(total):
            candidate = (start_index + attempt) % total
            if self._is_key_healthy(candidate):
                return candidate

        # All keys failing — reset failure counts and use start_index
        logger.warning("⚠ All Groq keys unhealthy. Resetting failure counts...")
        for i in range(total):
            self.state["groq_failure_counts"][str(i)] = 0
        # Wrap: callers pass skip_index + 1 or a persisted index, either of
        # which may be >= total and would otherwise index past the list.
        return start_index % total

    # ================================================================ Groq

    def get_groq_key(self, skip_index: Optional[int] = None) -> Optional[str]:
        """
        Get next Groq API key with intelligent load balancing.

        Args:
            skip_index: If provided, skip this specific key index and return next healthy key

        Returns:
            API key string, or None if no keys available
        """
        if not self.groq_keys:
            return None

        if skip_index is not None:
            # Skip the provided index and find next healthy key
            self.groq_index = self._get_next_healthy_key_index(skip_index + 1)
            logger.info(f"⟲ Skipped Groq key {skip_index}, using key {self.groq_index}")
        else:
            # Normal round-robin to next healthy key
            self.groq_index = self._get_next_healthy_key_index(self.groq_index)

        key = self.groq_keys[self.groq_index]

        # Advance index for next call
        next_index = (self.groq_index + 1) % len(self.groq_keys)
        self.groq_index = self._get_next_healthy_key_index(next_index)

        return key

    def track_groq_call(self, key: str, success: bool = True):
        """
        Track an actual Groq API call (not just key retrieval).
        Call this AFTER making an actual API request.

        Args:
            key: The API key that was used
            success: Whether the call succeeded
        """
        try:
            key_index = self.groq_keys.index(key)
        except ValueError:
            logger.warning("Call tracked for unknown Groq key")
            return

        key_idx_str = str(key_index)
        call_counts = self.state["groq_call_counts"]
        failure_counts = self.state["groq_failure_counts"]

        call_counts[key_idx_str] = call_counts.get(key_idx_str, 0) + 1
        self.state["groq_total_calls"] += 1

        if success:
            # Reset failures on success
            failure_counts[key_idx_str] = 0
        else:
            # Increment failures
            failure_counts[key_idx_str] = failure_counts.get(key_idx_str, 0) + 1
            self.state["groq_total_failures"] += 1

        self._save_state()

    def mark_groq_key_failure(self, key: str):
        """Record a failure for a Groq key (DEPRECATED — use track_groq_call instead)."""
        self.track_groq_call(key, success=False)

    def mark_groq_key_success(self, key: str):
        """Clear failures for a Groq key after successful use (DEPRECATED — use track_groq_call instead)."""
        self.track_groq_call(key, success=True)

    def get_groq_key_count(self) -> int:
        """Return total number of Groq keys."""
        return len(self.groq_keys)

    def get_groq_stats(self) -> Dict[str, Any]:
        """
        Get detailed statistics about Groq key usage.

        Returns:
            Dict with total_keys, total_calls, total_failures, current_index
            and a per_key mapping (int index -> calls / failures / healthy).
        """
        stats: Dict[str, Any] = {
            "total_keys": len(self.groq_keys),
            "total_calls": self.state["groq_total_calls"],
            "total_failures": self.state["groq_total_failures"],
            "current_index": self.groq_index,
            "per_key": {},
        }

        for i in range(len(self.groq_keys)):
            idx_str = str(i)
            failures = self.state["groq_failure_counts"].get(idx_str, 0)
            stats["per_key"][i] = {
                "calls": self.state["groq_call_counts"].get(idx_str, 0),
                "failures": failures,
                # Computed inline: _is_key_healthy() logs a "skipping"
                # warning, which is misleading for a read-only stats query.
                "healthy": failures < self.MAX_KEY_FAILURES,
            }

        return stats

    def is_groq_available(self) -> bool:
        """Check if any Groq keys are available."""
        return len(self.groq_keys) > 0

    # ================================================== HuggingFace

    def get_hf_token(self) -> Optional[str]:
        """Get HuggingFace token."""
        return self.hf_token

    def is_hf_available(self) -> bool:
        """HF Inference API works without a token (rate-limited)."""
        return True

    # ================================================== Legacy stubs

    def get_together_key(self, skip_index: Optional[int] = None) -> Optional[str]:
        """Stub — Together AI removed. Returns None."""
        return None

    def get_together_key_count(self) -> int:
        """Stub — Together AI removed. Returns 0."""
        return 0

    def is_together_available(self) -> bool:
        """Stub — Together AI removed. Returns False."""
        return False

    def get_gemini_key(self, skip_index: Optional[int] = None) -> Optional[str]:
        """Stub — Gemini removed. Returns None."""
        return None

    def get_gemini_key_count(self) -> int:
        """Stub — Gemini removed. Returns 0."""
        return 0

    def is_gemini_available(self) -> bool:
        """Stub — Gemini removed. Returns False."""
        return False


# Global singleton
_api_manager = APIKeyManager()


def get_api_manager() -> APIKeyManager:
    """Return the module-wide APIKeyManager singleton."""
    return _api_manager


# Convenience wrappers
def get_next_groq_key(skip_index: Optional[int] = None) -> Optional[str]:
    """Return the next Groq key from the singleton (see APIKeyManager.get_groq_key)."""
    return get_api_manager().get_groq_key(skip_index)


def get_hf_token() -> Optional[str]:
    """Return the HuggingFace token held by the singleton, or None."""
    return get_api_manager().get_hf_token()


# Legacy stubs (kept so old imports don't break)
def get_next_together_key(skip_index: Optional[int] = None) -> Optional[str]:
    """Stub — Together AI removed. Returns None."""
    return None


def get_next_gemini_key(skip_index: Optional[int] = None) -> Optional[str]:
    """Stub — Gemini removed. Returns None."""
    return None