# topic-modelling / api_key_loader.py
# vvinayakkk — "Initial clean commit with LFS" (commit a1d17f8)
"""
api_key_loader.py — API key management for Groq and HuggingFace Inference API.
IMPROVEMENTS (v2):
- ADDED: Persistent state tracking (last_used_groq_index, per-key call counts)
- ADDED: Intelligent key rotation with failure awareness
- ADDED: Health tracking for each API key
- ADDED: State persistence via JSON state file
- KEPT: Groq key rotation (enhanced round-robin)
- REMOVED: Together AI (unreliable, rate-limited); no-op stubs kept for backward compat
- REMOVED: Gemini; no-op stubs kept for backward compat
Supported env vars:
GROQ_API_KEY — single Groq key
GROQ_API_KEYS — comma-separated Groq keys (rotated with tracking)
HF_TOKEN — HuggingFace token (optional, increases rate limits)
State file: .api_state.json (relative to the current working directory) — tracks per-key call counts, consecutive failure counts, and the rotation index
"""
import os
import json
import logging
from typing import List, Optional, Dict, Any
from pathlib import Path
from dotenv import load_dotenv
# Populate os.environ from a .env file (python-dotenv) at import time, so the
# GROQ_* / HF_TOKEN variables read by APIKeyManager below can be supplied there.
load_dotenv()
logger = logging.getLogger(__name__)
class APIKeyManager:
    """
    Manages API keys for Groq and HuggingFace with health-aware load balancing.

    Features:
    - Round-robin rotation over Groq keys, skipping keys whose consecutive
      failure count reached ``MAX_KEY_FAILURES``
    - Per-key call and failure counters
    - State persisted as JSON to ``STATE_FILE`` (a path relative to the
      current working directory) after every tracked call

    Per-key counters are keyed by the key's *position* in the configured key
    list (stored as a string, since JSON object keys are strings), so
    reordering ``GROQ_API_KEYS`` between runs re-attributes persisted counters.
    """

    STATE_FILE = ".api_state.json"
    MAX_KEY_FAILURES = 5  # Skip a key after this many consecutive failures

    def __init__(self):
        self.groq_keys: List[str] = []
        # Index of the next Groq key to hand out.
        self.groq_index: int = 0
        self.hf_token: Optional[str] = None
        # State tracking for load balancing (persisted as JSON by _save_state).
        self.state: Dict[str, Any] = {
            # Despite its name, this holds the *next* index to hand out:
            # _save_state copies self.groq_index, which get_groq_key has
            # already advanced past the key it returned.
            "last_used_groq_index": 0,
            "groq_call_counts": {},     # str(key_index) -> call count
            "groq_failure_counts": {},  # str(key_index) -> consecutive failures
            "groq_total_calls": 0,
            "groq_total_failures": 0,
        }
        # Legacy stubs — kept for backward compat but unused
        self.together_keys: List[str] = []
        self.together_index: int = 0
        self.gemini_keys: List[str] = []
        self.gemini_index: int = 0
        self._load_keys()
        self._load_state()

    @staticmethod
    def _read_key_list(multi_var: str, single_var: str) -> List[str]:
        """Read API keys from the environment.

        Args:
            multi_var: Env var holding comma-separated keys (preferred).
            single_var: Env var holding one key; consulted only when
                ``multi_var`` yields no non-blank keys.

        Returns:
            The stripped, non-empty keys in configured order (may be empty).
        """
        keys = [k.strip() for k in os.getenv(multi_var, "").split(",") if k.strip()]
        if keys:
            return keys
        single = os.getenv(single_var, "").strip()
        return [single] if single else []

    @staticmethod
    def _is_int(value: Any) -> bool:
        """Return True for genuine ints (bools excluded, although bool subclasses int)."""
        return isinstance(value, int) and not isinstance(value, bool)

    def _load_keys(self):
        """Load API keys and the HuggingFace token from the environment."""
        # --- Groq ---
        self.groq_keys = self._read_key_list("GROQ_API_KEYS", "GROQ_API_KEY")
        # Initialize per-key counters for every configured key.
        for i in range(len(self.groq_keys)):
            self.state["groq_call_counts"].setdefault(str(i), 0)
            self.state["groq_failure_counts"].setdefault(str(i), 0)
        # --- HuggingFace ---
        self.hf_token = os.getenv("HF_TOKEN", "").strip() or None
        # --- Legacy Together AI & Gemini (read, but never handed out) ---
        self.together_keys = self._read_key_list("TOGETHER_API_KEYS", "TOGETHER_API_KEY")
        self.gemini_keys = self._read_key_list("GEMINI_API_KEYS", "GEMINI_API_KEY")

    def _load_state(self):
        """Merge persisted state from ``STATE_FILE`` into ``self.state``.

        A missing file is ignored. An unreadable, non-JSON or non-object file
        is logged and ignored, so start-up never fails because of it.
        Counter values of the wrong type are dropped rather than merged, so
        later arithmetic on them cannot raise. The restored rotation index is
        wrapped into range in case the number of configured keys changed.
        """
        state_path = Path(self.STATE_FILE)
        if not state_path.exists():
            return
        try:
            with open(state_path, "r", encoding="utf-8") as f:
                saved_state = json.load(f)
        except (OSError, ValueError) as e:  # JSONDecodeError/UnicodeDecodeError are ValueErrors
            logger.warning(f"Failed to load API state: {e}. Starting fresh.")
            return
        if not isinstance(saved_state, dict):
            logger.warning("Failed to load API state: expected a JSON object. Starting fresh.")
            return

        for name, value in saved_state.items():
            if name in ("groq_call_counts", "groq_failure_counts"):
                # Merge (not replace) so defaults for newly added keys survive.
                if isinstance(value, dict):
                    self.state[name].update(
                        {k: v for k, v in value.items() if self._is_int(v)}
                    )
            elif name in ("last_used_groq_index", "groq_total_calls", "groq_total_failures"):
                if self._is_int(value):
                    self.state[name] = value
            else:
                # Unknown entries are kept so they round-trip on the next save.
                self.state[name] = value

        saved_index = self.state["last_used_groq_index"]
        self.groq_index = saved_index % len(self.groq_keys) if self.groq_keys else 0
        logger.info(f"✓ Loaded API state: Groq index={self.groq_index}, "
                    f"calls={self.state['groq_total_calls']}, "
                    f"failures={self.state['groq_total_failures']}")

    def _save_state(self):
        """Persist state to ``STATE_FILE`` (best effort; failures are logged).

        The JSON is written to a sibling ``.tmp`` file and then moved into
        place with ``os.replace``, so a crash mid-write cannot leave a
        truncated state file behind.
        """
        self.state["last_used_groq_index"] = self.groq_index
        state_path = Path(self.STATE_FILE)
        tmp_path = state_path.with_name(state_path.name + ".tmp")
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(self.state, f, indent=2)
            os.replace(tmp_path, state_path)
        except (OSError, TypeError, ValueError) as e:
            logger.warning(f"Failed to save API state: {e}")

    def _failure_count(self, key_index: int) -> int:
        """Return the consecutive-failure count recorded for ``key_index``."""
        return self.state["groq_failure_counts"].get(str(key_index), 0)

    def _is_key_healthy(self, key_index: int) -> bool:
        """Return True if the key is below the failure threshold (logs a warning if not)."""
        failures = self._failure_count(key_index)
        is_healthy = failures < self.MAX_KEY_FAILURES
        if not is_healthy:
            logger.warning(f"⚠ Groq key {key_index} has {failures} failures — skipping")
        return is_healthy

    def _get_next_healthy_key_index(self, start_index: int) -> int:
        """Return the first healthy key index at or after ``start_index`` (wrapping).

        ``start_index`` may be out of range (e.g. ``skip_index + 1`` for the
        last key); it is reduced modulo the key count. If every key is
        unhealthy, all failure counts are reset (in memory only; they are
        persisted on the next tracked call) and the wrapped ``start_index``
        is returned.
        """
        count = len(self.groq_keys)
        if count == 0:
            return 0
        for offset in range(count):
            candidate = (start_index + offset) % count
            if self._is_key_healthy(candidate):
                return candidate
        logger.warning("⚠ All Groq keys unhealthy. Resetting failure counts...")
        for i in range(count):
            self.state["groq_failure_counts"][str(i)] = 0
        # Wrap: returning the raw start_index here used to cause IndexError.
        return start_index % count

    # ================================================================ Groq
    def get_groq_key(self, skip_index: Optional[int] = None) -> Optional[str]:
        """
        Get the next Groq API key using health-aware round-robin rotation.

        Args:
            skip_index: If provided, the search starts at ``skip_index + 1``
                (e.g. the index of a key that just failed). The skipped key
                can still be returned if it is the only healthy one.
        Returns:
            API key string, or None if no keys are configured.

        The advanced rotation index is only persisted to disk by the next
        ``track_groq_call``.
        """
        if not self.groq_keys:
            return None
        if skip_index is not None:
            self.groq_index = self._get_next_healthy_key_index(skip_index + 1)
            logger.info(f"⟲ Skipped Groq key {skip_index}, using key {self.groq_index}")
        else:
            self.groq_index = self._get_next_healthy_key_index(self.groq_index)
        key = self.groq_keys[self.groq_index]
        # Advance now so the next call starts at the following healthy key.
        next_index = (self.groq_index + 1) % len(self.groq_keys)
        self.groq_index = self._get_next_healthy_key_index(next_index)
        return key

    def track_groq_call(self, key: str, success: bool = True):
        """
        Track an actual Groq API call (not just key retrieval) and persist state.

        Call this AFTER making an actual API request. A success resets the
        key's consecutive-failure count; a failure increments it.

        Args:
            key: The API key that was used (unknown keys are logged and ignored).
            success: Whether the call succeeded.
        """
        try:
            key_index = self.groq_keys.index(key)
        except ValueError:
            logger.warning("Call tracked for unknown Groq key")
            return
        idx = str(key_index)
        call_counts = self.state["groq_call_counts"]
        failure_counts = self.state["groq_failure_counts"]
        call_counts[idx] = call_counts.get(idx, 0) + 1
        self.state["groq_total_calls"] += 1
        if success:
            failure_counts[idx] = 0
        else:
            failure_counts[idx] = failure_counts.get(idx, 0) + 1
            self.state["groq_total_failures"] += 1
        self._save_state()

    def mark_groq_key_failure(self, key: str):
        """Record a failure for a Groq key (DEPRECATED — use track_groq_call instead)."""
        self.track_groq_call(key, success=False)

    def mark_groq_key_success(self, key: str):
        """Clear failures for a Groq key after successful use (DEPRECATED — use track_groq_call instead)."""
        self.track_groq_call(key, success=True)

    def get_groq_key_count(self) -> int:
        """Return total number of Groq keys."""
        return len(self.groq_keys)

    def get_groq_stats(self) -> Dict[str, Any]:
        """Return usage statistics: totals, current rotation index, and per-key counters."""
        stats: Dict[str, Any] = {
            "total_keys": len(self.groq_keys),
            "total_calls": self.state["groq_total_calls"],
            "total_failures": self.state["groq_total_failures"],
            "current_index": self.groq_index,
            "per_key": {},
        }
        for i in range(len(self.groq_keys)):
            failures = self._failure_count(i)
            stats["per_key"][i] = {
                "calls": self.state["groq_call_counts"].get(str(i), 0),
                "failures": failures,
                # Computed directly so reporting stats doesn't emit "skipping" warnings.
                "healthy": failures < self.MAX_KEY_FAILURES,
            }
        return stats

    def is_groq_available(self) -> bool:
        """Check if any Groq keys are configured."""
        return len(self.groq_keys) > 0

    # ================================================== HuggingFace
    def get_hf_token(self) -> Optional[str]:
        """Get HuggingFace token (None if HF_TOKEN is unset or blank)."""
        return self.hf_token

    def is_hf_available(self) -> bool:
        """HF Inference API works without a token (rate-limited), so always True."""
        return True

    # ================================================== Legacy stubs
    def get_together_key(self, skip_index: Optional[int] = None) -> Optional[str]:
        """Stub — Together AI removed. Returns None."""
        return None

    def get_together_key_count(self) -> int:
        """Stub — Together AI removed. Returns 0."""
        return 0

    def is_together_available(self) -> bool:
        """Stub — Together AI removed. Returns False."""
        return False

    def get_gemini_key(self, skip_index: Optional[int] = None) -> Optional[str]:
        """Stub — Gemini removed. Returns None."""
        return None

    def get_gemini_key_count(self) -> int:
        """Stub — Gemini removed. Returns 0."""
        return 0

    def is_gemini_available(self) -> bool:
        """Stub — Gemini removed. Returns False."""
        return False
# Global singleton: built once at import time, which reads the environment
# and the state file via APIKeyManager.__init__.
_api_manager: APIKeyManager = APIKeyManager()


def get_api_manager() -> APIKeyManager:
    """Return the process-wide APIKeyManager created when this module was imported."""
    shared_manager = _api_manager
    return shared_manager
# Convenience wrappers around the global manager
def get_next_groq_key(skip_index: Optional[int] = None) -> Optional[str]:
    """Return the next Groq key from the shared manager (see APIKeyManager.get_groq_key)."""
    manager = get_api_manager()
    return manager.get_groq_key(skip_index=skip_index)


def get_hf_token() -> Optional[str]:
    """Return the shared manager's HuggingFace token, or None if none is configured."""
    manager = get_api_manager()
    return manager.get_hf_token()
# Legacy stubs (kept so old imports don't break)
def get_next_together_key(skip_index: Optional[int] = None) -> Optional[str]:
    """Legacy no-op: Together AI support was removed, so no key is ever returned."""
    no_key: Optional[str] = None
    return no_key


def get_next_gemini_key(skip_index: Optional[int] = None) -> Optional[str]:
    """Legacy no-op: Gemini support was removed, so no key is ever returned."""
    no_key: Optional[str] = None
    return no_key