Spaces:
Running
Running
"""
api_key_loader.py — API key management for Groq and the HuggingFace Inference API.

IMPROVEMENTS (v2):
- ADDED: Persistent state tracking (last_used_groq_index, per-key call counts)
- ADDED: Intelligent key rotation with failure awareness
- ADDED: Health tracking for each API key
- ADDED: State persistence via a JSON state file
- KEPT: Groq key rotation (enhanced round-robin)
- REMOVED: Together AI support (unreliable, rate-limited); stubs kept for backward compat only
- REMOVED: Gemini support; stubs kept for backward compat only

Supported env vars:
    GROQ_API_KEY  — single Groq key
    GROQ_API_KEYS — comma-separated Groq keys (rotated with tracking)
    HF_TOKEN      — HuggingFace token (optional, increases rate limits)

State file: .api_state.json — tracks per-key call counts, failure counts, and the rotation index
"""
| import os | |
| import json | |
| import logging | |
| from typing import List, Optional, Dict, Any | |
| from pathlib import Path | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| logger = logging.getLogger(__name__) | |
class APIKeyManager:
    """
    Manages API keys for Groq and HuggingFace with round-robin load balancing.

    Features:
    - Round-robin rotation over the configured Groq keys
    - Consecutive-failure counting per key; a key that reaches
      MAX_KEY_FAILURES is skipped until it succeeds or all keys are reset
    - Call count tracking per key
    - State persisted to STATE_FILE (JSON) after every tracked call

    Per-key counters are stored under the key's list position rendered as a
    string ("0", "1", ...) so the state round-trips through JSON unchanged.
    """

    STATE_FILE = ".api_state.json"
    MAX_KEY_FAILURES = 5  # Skip key after N consecutive failures

    # Persisted fields, grouped by expected JSON type; used to validate a
    # state file before it is merged into self.state.
    _COUNTER_MAP_FIELDS = ("groq_call_counts", "groq_failure_counts")
    _COUNTER_FIELDS = ("last_used_groq_index", "groq_total_calls", "groq_total_failures")

    def __init__(self):
        """Read keys from the environment, then restore any persisted state."""
        self.groq_keys: List[str] = []
        self.groq_index: int = 0
        self.hf_token: Optional[str] = None

        # State tracking for load balancing
        self.state: Dict[str, Any] = {
            "last_used_groq_index": 0,
            "groq_call_counts": {},     # key_index -> call_count
            "groq_failure_counts": {},  # key_index -> consecutive_failure_count
            "groq_total_calls": 0,
            "groq_total_failures": 0,
        }

        # Legacy stubs — kept for backward compat but unused
        self.together_keys: List[str] = []
        self.together_index: int = 0
        self.gemini_keys: List[str] = []
        self.gemini_index: int = 0

        self._load_keys()
        self._load_state()

    # ============================================================ helpers
    @staticmethod
    def _read_keys_from_env(plural_var: str, single_var: str) -> List[str]:
        """
        Read API keys from the environment.

        Args:
            plural_var: Env var holding a comma-separated list of keys.
            single_var: Env var holding one key; used only when plural_var
                yields no keys.
        Returns:
            List of non-empty, whitespace-stripped keys (possibly empty).
        """
        keys = [part.strip() for part in os.getenv(plural_var, "").split(",") if part.strip()]
        if keys:
            return keys
        single = os.getenv(single_var, "").strip()
        return [single] if single else []

    @staticmethod
    def _is_count(value: Any) -> bool:
        """Return True when value is a non-negative int (bool excluded)."""
        return isinstance(value, int) and not isinstance(value, bool) and value >= 0

    def _seed_key_counters(self) -> None:
        """Ensure every configured Groq key has call/failure counter entries."""
        for slot in map(str, range(len(self.groq_keys))):
            self.state["groq_call_counts"].setdefault(slot, 0)
            self.state["groq_failure_counts"].setdefault(slot, 0)

    def _wrap_index(self, index: int) -> int:
        """Fold index into the valid range of groq_keys (0 when there are no keys)."""
        return index % len(self.groq_keys) if self.groq_keys else 0

    # ============================================================ loading
    def _load_keys(self):
        """Load API keys from environment."""
        # --- Groq ---
        self.groq_keys = self._read_keys_from_env("GROQ_API_KEYS", "GROQ_API_KEY")
        self._seed_key_counters()

        # --- HuggingFace ---
        self.hf_token = os.getenv("HF_TOKEN", "").strip() or None

        # --- Legacy Together AI & Gemini (stubs only) ---
        self.together_keys = self._read_keys_from_env("TOGETHER_API_KEYS", "TOGETHER_API_KEY")
        self.gemini_keys = self._read_keys_from_env("GEMINI_API_KEYS", "GEMINI_API_KEY")

    def _load_state(self):
        """
        Load persisted state from disk.

        A missing, unreadable, or malformed file leaves the in-memory defaults
        untouched. Individual fields of the wrong type are dropped rather than
        merged, and the restored rotation index is wrapped to the number of
        keys configured now (the key list may have shrunk since the file was
        written).
        """
        state_path = Path(self.STATE_FILE)
        if not state_path.exists():
            return

        try:
            with open(state_path, "r", encoding="utf-8") as f:
                saved_state = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.warning(f"Failed to load API state: {e}. Starting fresh.")
            return

        if not isinstance(saved_state, dict):
            logger.warning(
                f"Failed to load API state: expected a JSON object, "
                f"got {type(saved_state).__name__}. Starting fresh."
            )
            return

        clean: Dict[str, Any] = dict(saved_state)
        for name in self._COUNTER_MAP_FIELDS:
            if name not in clean:
                continue
            raw = clean[name]
            if isinstance(raw, dict):
                # Keep only well-formed entries so later comparisons cannot fail.
                clean[name] = {str(k): v for k, v in raw.items() if self._is_count(v)}
            else:
                logger.warning(f"Ignoring malformed '{name}' in API state file")
                del clean[name]
        for name in self._COUNTER_FIELDS:
            if name in clean and not self._is_count(clean[name]):
                logger.warning(f"Ignoring malformed '{name}' in API state file")
                del clean[name]

        self.state.update(clean)
        # update() replaces the counter maps wholesale, so re-seed the
        # entries for keys that were not present when the file was written.
        self._seed_key_counters()
        self.groq_index = self._wrap_index(self.state.get("last_used_groq_index", 0))

        logger.info(f"✓ Loaded API state: Groq index={self.groq_index}, "
                    f"calls={self.state['groq_total_calls']}, "
                    f"failures={self.state['groq_total_failures']}")

    def _save_state(self):
        """Persist state to disk (best effort — failures are logged, not raised)."""
        self.state["last_used_groq_index"] = self.groq_index
        state_path = Path(self.STATE_FILE)
        try:
            with open(state_path, "w", encoding="utf-8") as f:
                json.dump(self.state, f, indent=2)
        except (OSError, TypeError, ValueError) as e:
            logger.warning(f"Failed to save API state: {e}")

    # ============================================================ health
    def _is_key_healthy(self, key_index: int) -> bool:
        """
        Check if a key is healthy (below the consecutive-failure threshold).

        This is a pure query with no logging, so it is safe to call from
        reporting code such as get_groq_stats.
        """
        failures = self.state["groq_failure_counts"].get(str(key_index), 0)
        return failures < self.MAX_KEY_FAILURES

    def _get_next_healthy_key_index(self, start_index: int) -> int:
        """
        Find the next healthy key, scanning forward from start_index.

        Args:
            start_index: Position to start from; values outside the key list
                are wrapped with modulo.
        Returns:
            Index of the first healthy key found. When every key is unhealthy,
            all failure counts are reset and start_index (wrapped into range)
            is returned.
        """
        key_count = len(self.groq_keys)
        if key_count == 0:
            # Nothing to rotate over; callers guard against this case.
            return start_index

        for offset in range(key_count):
            candidate = (start_index + offset) % key_count
            if self._is_key_healthy(candidate):
                return candidate
            failures = self.state["groq_failure_counts"].get(str(candidate), 0)
            logger.warning(f"⚠ Groq key {candidate} has {failures} failures — skipping")

        # All keys failing — reset failure counts and use start_index
        logger.warning("⚠ All Groq keys unhealthy. Resetting failure counts...")
        for slot in map(str, range(key_count)):
            self.state["groq_failure_counts"][slot] = 0
        # Wrap: start_index may be skip_index + 1 == key_count, or a stale value.
        return start_index % key_count

    # ================================================================ Groq
    def get_groq_key(self, skip_index: Optional[int] = None) -> Optional[str]:
        """
        Get next Groq API key with intelligent load balancing.

        Args:
            skip_index: If provided, skip this specific key index and return next healthy key
        Returns:
            API key string, or None if no keys available
        """
        if not self.groq_keys:
            return None

        if skip_index is not None:
            # Skip the provided index and find next healthy key
            self.groq_index = self._get_next_healthy_key_index(skip_index + 1)
            logger.info(f"⟲ Skipped Groq key {skip_index}, using key {self.groq_index}")
        else:
            # Normal round-robin to next healthy key
            self.groq_index = self._get_next_healthy_key_index(self.groq_index)

        key = self.groq_keys[self.groq_index]

        # Advance index for next call
        next_index = (self.groq_index + 1) % len(self.groq_keys)
        self.groq_index = self._get_next_healthy_key_index(next_index)
        return key

    def track_groq_call(self, key: str, success: bool = True):
        """
        Track an actual Groq API call (not just key retrieval).
        Call this AFTER making an actual API request.

        Args:
            key: The API key that was used
            success: Whether the call succeeded
        """
        try:
            key_index = self.groq_keys.index(key)
        except ValueError:
            logger.warning("Call tracked for unknown Groq key")
            return

        slot = str(key_index)
        call_counts = self.state["groq_call_counts"]
        failure_counts = self.state["groq_failure_counts"]

        call_counts[slot] = call_counts.get(slot, 0) + 1
        self.state["groq_total_calls"] += 1

        if success:
            # Reset failures on success
            failure_counts[slot] = 0
        else:
            # Increment failures
            failure_counts[slot] = failure_counts.get(slot, 0) + 1
            self.state["groq_total_failures"] += 1

        self._save_state()

    def mark_groq_key_failure(self, key: str):
        """Record a failure for a Groq key (DEPRECATED — use track_groq_call instead)."""
        self.track_groq_call(key, success=False)

    def mark_groq_key_success(self, key: str):
        """Clear failures for a Groq key after successful use (DEPRECATED — use track_groq_call instead)."""
        self.track_groq_call(key, success=True)

    def get_groq_key_count(self) -> int:
        """Return total number of Groq keys."""
        return len(self.groq_keys)

    def get_groq_stats(self) -> Dict[str, Any]:
        """
        Get detailed statistics about Groq key usage.

        Returns:
            Dict with "total_keys", "total_calls", "total_failures",
            "current_index", and "per_key" (int key index -> dict with
            "calls", "failures", "healthy").
        """
        per_key: Dict[int, Dict[str, Any]] = {}
        for i in range(len(self.groq_keys)):
            slot = str(i)
            per_key[i] = {
                "calls": self.state["groq_call_counts"].get(slot, 0),
                "failures": self.state["groq_failure_counts"].get(slot, 0),
                "healthy": self._is_key_healthy(i),
            }
        return {
            "total_keys": len(self.groq_keys),
            "total_calls": self.state["groq_total_calls"],
            "total_failures": self.state["groq_total_failures"],
            "current_index": self.groq_index,
            "per_key": per_key,
        }

    def is_groq_available(self) -> bool:
        """Check if any Groq keys are available."""
        return len(self.groq_keys) > 0

    # ================================================== HuggingFace
    def get_hf_token(self) -> Optional[str]:
        """Get HuggingFace token."""
        return self.hf_token

    def is_hf_available(self) -> bool:
        """HF Inference API works without a token (rate-limited)."""
        return True

    # ================================================== Legacy stubs
    def get_together_key(self, skip_index: Optional[int] = None) -> Optional[str]:
        """Stub — Together AI removed. Returns None."""
        return None

    def get_together_key_count(self) -> int:
        """Stub — Together AI removed. Returns 0."""
        return 0

    def is_together_available(self) -> bool:
        """Stub — Together AI removed. Returns False."""
        return False

    def get_gemini_key(self, skip_index: Optional[int] = None) -> Optional[str]:
        """Stub — Gemini removed. Returns None."""
        return None

    def get_gemini_key_count(self) -> int:
        """Stub — Gemini removed. Returns 0."""
        return 0

    def is_gemini_available(self) -> bool:
        """Stub — Gemini removed. Returns False."""
        return False
| # Global singleton | |
# Built once at import time; construction reads the environment and any
# saved state file.
_api_manager: APIKeyManager = APIKeyManager()


def get_api_manager() -> APIKeyManager:
    """Return the module-level APIKeyManager instance shared by all callers."""
    return _api_manager
| # Convenience wrappers | |
def get_next_groq_key(skip_index: Optional[int] = None) -> Optional[str]:
    """
    Return the next Groq key from the shared manager.

    Args:
        skip_index: Forwarded unchanged to APIKeyManager.get_groq_key.
    Returns:
        Whatever get_groq_key returns: a key string, or None when no Groq
        keys are configured.
    """
    manager = get_api_manager()
    return manager.get_groq_key(skip_index)
def get_hf_token() -> Optional[str]:
    """Return the HuggingFace token held by the shared manager, or None if unset."""
    manager = get_api_manager()
    return manager.get_hf_token()
| # Legacy stubs (kept so old imports don't break) | |
def get_next_together_key(skip_index: Optional[int] = None) -> Optional[str]:
    """
    Legacy stub kept so old imports keep working; always returns None.

    Args:
        skip_index: Accepted for signature compatibility and ignored.
    """
    # Together AI support was removed, so there is never a key to hand out.
    return None
def get_next_gemini_key(skip_index: Optional[int] = None) -> Optional[str]:
    """
    Legacy stub kept so old imports keep working; always returns None.

    Args:
        skip_index: Accepted for signature compatibility and ignored.
    """
    # Gemini support was removed, so there is never a key to hand out.
    return None