Spaces:
Paused
Paused
| from threading import Lock | |
| import os | |
| from typing import List, Optional, Literal, Union, Dict | |
| from dotenv import load_dotenv | |
| import re | |
| from langchain_xai import ChatXAI | |
| from langchain_openai import ChatOpenAI | |
| from langchain_anthropic import ChatAnthropic | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from functools import wraps | |
| import time | |
| from openai import RateLimitError, OpenAIError | |
| from anthropic import RateLimitError as AnthropicRateLimitError, APIError as AnthropicAPIError | |
| from google.api_core.exceptions import ResourceExhausted, BadRequest, InvalidArgument | |
| from tenacity import retry, wait_random_exponential, stop_after_attempt, retry_if_exception_type | |
| import asyncio | |
# Tell Pydantic to finish wiring up each LangChain chat model class before
# any instance of it is constructed.
for _chat_model_cls in (ChatGoogleGenerativeAI, ChatAnthropic, ChatOpenAI, ChatXAI):
    _chat_model_cls.model_rebuild()
del _chat_model_cls

# Identifiers of the LLM providers this module can route requests to.
ModelProvider = Literal["openai", "anthropic", "google", "xai"]
class APIKeyManager:
    """Process-wide singleton that loads provider API keys and builds LLM clients.

    Keys are read from the environment: numbered ``<PROVIDER>_API_KEY_<n>``
    variables (ordered by ``n``), or a single ``<PROVIDER>_API_KEY`` when no
    numbered ones exist. Keys are handed out round-robin per provider. The
    active provider comes from ``MODEL_PROVIDER``; default model and sampling
    settings come from ``MODEL_NAME``, ``MODEL_TEMPERATURE`` and ``MODEL_TOP_P``.
    """

    _instance = None
    # Guards singleton creation in __new__; each instance also gets its own
    # lock (set in __init__) that guards key rotation.
    _lock = Lock()

    # Define supported models: MODEL_NAME must appear under MODEL_PROVIDER.
    SUPPORTED_MODELS = {
        "openai": [
            "o1-mini",
            "o1",
            "o1-pro",
            "o3-mini",
            "o3",
            "o4-mini",
            "gpt-4o-mini-2024-07-18",
            "gpt-4o-mini",
            "chatgpt-4o-latest",
            "gpt-4o-2024-05-13",
            "gpt-4o-2024-08-06",
            "gpt-4o-2024-11-20",
            "gpt-4o",
            "gpt-4.1-nano",
            "gpt-4.1-mini",
            "gpt-4.1"
        ],
        "google": [
            "gemini-2.0-pro-exp-02-05",
            "gemini-2.0-flash-lite-preview-02-05",
            "gemini-2.0-flash-exp",
            "gemini-2.0-flash",
            "gemini-2.0-flash-thinking-exp-1219",
            "gemini-2.5-flash-lite-preview-06-17",
            "gemini-2.5-flash-preview-04-17",
            "gemini-2.5-flash",
            "gemini-2.5-pro"
        ],
        "xai": [
            "grok-2",
            "grok-3-mini-latest",
            "grok-3-mini-fast-latest",
            "grok-3-latest",
            "grok-3-fast-latest"
        ],
        "anthropic": [
            "claude-opus-4-20250514",
            "claude-sonnet-4-20250514",
            "claude-3-7-sonnet-20250219",
            "claude-3-5-sonnet-20241022",
            "claude-3-5-sonnet-latest",
            "claude-3-5-haiku-20241022",
            "claude-3-5-haiku-latest",
            "claude-3-opus-20240229",
            "claude-3-opus-latest",
            "claude-3-sonnet-20240229",
            "claude-3-haiku-20240307"
        ]
    }

    # Environment-variable prefix for each provider's keys. The insertion
    # order fixes the order of the dict returned by _load_api_keys.
    _ENV_PREFIXES = {
        "openai": "OPENAI_API_KEY",
        "anthropic": "ANTHROPIC_API_KEY",
        "google": "GOOGLE_API_KEY",
        "xai": "XAI_API_KEY",
    }

    def __new__(cls):
        with cls._lock:
            if cls._instance is None:
                cls._instance = super(APIKeyManager, cls).__new__(cls)
                cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Load keys and default settings once; later constructions are no-ops.

        Raises:
            Exception: if no API key is found for any provider. In that case
                the instance stays uninitialized, so the next construction
                retries instead of returning a half-built object.
        """
        if self._initialized:
            return
        # 1) Always load env
        load_dotenv(override=True)
        self._current_indices = {
            "openai": 0,
            "anthropic": 0,
            "google": 0,
            "xai": 0
        }
        self._lock = Lock()
        # 2) load all provider keys from environment
        self._api_keys = self._load_api_keys()
        self._llm = None
        self._current_provider = None
        # (model_name, temperature, top_p, max_tokens, streaming) of self._llm.
        self._llm_config = None
        # 3) read user's chosen model, temperature, top_p from env
        # NOTE(review): the default "gpt-3.5-turbo" is not listed in
        # SUPPORTED_MODELS["openai"], so _get_provider_for_model rejects it
        # when MODEL_NAME is unset — confirm the intended default.
        self.model_name = os.getenv("MODEL_NAME", "gpt-3.5-turbo").strip()
        try:
            self.temperature = float(os.getenv("MODEL_TEMPERATURE", "0"))
        except ValueError:
            self.temperature = 0.0
        try:
            self.top_p = float(os.getenv("MODEL_TOP_P", "1"))
        except ValueError:
            self.top_p = 1.0
        # Mark initialized only after every step above succeeded.
        self._initialized = True

    def _reinit(self):
        """Force a full reload of keys and env-derived settings."""
        self._initialized = False
        self.__init__()

    @staticmethod
    def _mask_key(api_key: str) -> str:
        """Return a shortened form of *api_key* that is safe to print."""
        return f"{api_key[:4]}...{api_key[-4:]}" if len(api_key) > 8 else "****"

    @staticmethod
    def _keys_for_prefix(prefix: str, env_vars: Dict[str, str]) -> List[str]:
        """Return the non-blank keys stored under *prefix*, in rotation order.

        ``<prefix>_<n>`` variables are returned sorted numerically by ``n``.
        Only when none exist is the bare ``<prefix>`` variable used.
        """
        pattern = re.compile(rf"{re.escape(prefix)}_\d+$")
        numbered = {k: v for k, v in env_vars.items() if pattern.match(k) and v.strip()}
        if not numbered:
            default_key = os.getenv(prefix)
            return [default_key] if default_key and default_key.strip() else []
        ordered_names = sorted(numbered, key=lambda name: int(name.rsplit("_", 1)[-1]))
        return [numbered[name] for name in ordered_names]

    def _load_api_keys(self) -> Dict[str, List[str]]:
        """Load API keys from environment variables dynamically.

        Returns:
            Mapping of provider -> list of keys (possibly empty) in rotation order.

        Raises:
            Exception: if no provider has any key.
        """
        env_vars = dict(os.environ)
        api_keys = {
            provider: self._keys_for_prefix(prefix, env_vars)
            for provider, prefix in self._ENV_PREFIXES.items()
        }
        if not any(api_keys.values()):
            raise Exception("No valid API keys found in environment variables")
        for provider, keys in api_keys.items():
            if keys:
                print(f"Loaded {len(keys)} {provider} API keys for rotation")
        return api_keys

    def get_next_api_key(self, provider: ModelProvider) -> str:
        """Get the next API key in round-robin fashion for the specified provider.

        Raises:
            Exception: if the provider has no keys.
        """
        with self._lock:
            if not self._api_keys.get(provider):
                raise Exception(f"No API key found for {provider}")
            index = self._current_indices.setdefault(provider, 0)
            keys = self._api_keys[provider]
            self._current_indices[provider] = (index + 1) % len(keys)
            return keys[index]

    def _get_provider_for_model(self) -> ModelProvider:
        """Return MODEL_PROVIDER (re-read from .env) after validating it and self.model_name.

        Raises:
            Exception: if the provider is unknown or self.model_name is not
                listed under it in SUPPORTED_MODELS.
        """
        load_dotenv(override=True)  # to refresh in case .env changed
        provider_env = os.getenv("MODEL_PROVIDER", "openai").lower().strip()
        if provider_env not in self.SUPPORTED_MODELS:
            raise Exception(
                f"Invalid or missing MODEL_PROVIDER in env: '{provider_env}'. "
                f"Must be one of: {list(self.SUPPORTED_MODELS.keys())}"
            )
        # check if user-chosen model is in that provider's list
        if self.model_name not in self.SUPPORTED_MODELS[provider_env]:
            available = self.SUPPORTED_MODELS[provider_env]
            raise Exception(
                f"Model '{self.model_name}' is not available under provider '{provider_env}'. "
                f"Available: {available}"
            )
        return provider_env

    def _resolve_settings(
        self,
        model_name: Optional[str],
        temperature: Optional[float],
        top_p: Optional[float]
    ):
        """Fill unset overrides from the env-derived defaults.

        ``is None`` checks keep an explicit ``0``/``0.0`` instead of silently
        replacing it with the default.
        """
        return (
            model_name if model_name else self.model_name,
            self.temperature if temperature is None else temperature,
            self.top_p if top_p is None else top_p,
        )

    def _initialize_llm(
        self,
        model_name: Optional[str] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        max_tokens: Optional[int] = None,
        streaming: bool = False
    ):
        """Build a new client with the next API key in rotation and cache it."""
        # _get_provider_for_model refreshes .env itself.
        provider = self._get_provider_for_model()
        model_name, temperature, top_p = self._resolve_settings(model_name, temperature, top_p)
        api_key = self.get_next_api_key(provider)
        print(f"Using provider={provider}, model_name={model_name}, "
              f"temperature={temperature}, top_p={top_p}, key={self._mask_key(api_key)}")
        kwargs = {
            "model": model_name,
            "temperature": temperature,
            "top_p": top_p,
            "max_retries": 0,  # retries/rotation are handled by with_api_manager
            "streaming": streaming,
            "api_key": api_key
        }
        if max_tokens is not None:
            kwargs["max_tokens"] = max_tokens
        if provider == "openai":
            self._llm = ChatOpenAI(**kwargs)
        elif provider == "google":
            self._llm = ChatGoogleGenerativeAI(**kwargs)
        elif provider == "anthropic":
            self._llm = ChatAnthropic(**kwargs)
        else:
            self._llm = ChatXAI(**kwargs)
        self._current_provider = provider
        self._llm_config = (model_name, temperature, top_p, max_tokens, streaming)

    def get_llm(
        self,
        model_name: Optional[str] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        max_tokens: Optional[int] = None,
        streaming: bool = False
    ) -> Union[ChatOpenAI, ChatGoogleGenerativeAI, ChatAnthropic, ChatXAI]:
        """Return the cached client, rebuilding it (with the next key) when needed.

        A rebuild happens when there is no client yet, the provider changed,
        or any requested setting differs from the cached client's settings.
        Otherwise the same client, and therefore the same key, is reused.
        """
        provider = self._get_provider_for_model()
        model_name, temperature, top_p = self._resolve_settings(model_name, temperature, top_p)
        requested = (model_name, temperature, top_p, max_tokens, streaming)
        if (self._llm is None
                or provider != self._current_provider
                or requested != self._llm_config):
            self._initialize_llm(model_name, temperature, top_p, max_tokens, streaming)
        return self._llm

    def rotate_key(self, provider: Optional[ModelProvider] = None, streaming: bool = False) -> None:
        """Manually rotate to the next API key.

        Keeps the model/temperature/top_p/max_tokens of the current client;
        only ``streaming`` is taken from the argument. ``provider`` is accepted
        for backward compatibility but ignored: the provider is always re-read
        from MODEL_PROVIDER.
        """
        if self._llm_config is None:
            self._initialize_llm(streaming=streaming)
            return
        model_name, temperature, top_p, max_tokens, _ = self._llm_config
        self._initialize_llm(model_name, temperature, top_p, max_tokens, streaming)

    def get_all_api_keys(self, provider: Optional[ModelProvider] = None) -> Union[Dict[str, List[str]], List[str]]:
        """Get copies of the keys for *provider*, or of every provider's keys."""
        if provider:
            return self._api_keys[provider].copy()
        return {k: v.copy() for k, v in self._api_keys.items()}

    def get_key_count(self, provider: Optional[ModelProvider] = None) -> Union[Dict[str, int], int]:
        """Get the key count for *provider*, or a provider -> count mapping."""
        if provider:
            return len(self._api_keys[provider])
        return {k: len(v) for k, v in self._api_keys.items()}

    def __len__(self) -> int:
        """Return the total number of loaded API keys across all providers.

        (``len()`` requires an int; use get_key_count() for per-provider counts.)
        """
        return sum(len(keys) for keys in self._api_keys.values())

    def __bool__(self) -> bool:
        """Check if there are any API keys available."""
        return any(bool(keys) for keys in self._api_keys.values())
def with_api_manager(
    model_name: Optional[str] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_tokens: Optional[int] = None,
    streaming: bool = False,
    delay_on_timeout: int = 20,
    max_token_reduction_attempts: int = 0
):
    """Decorator for automatic key rotation on error with delay on timeout.

    The decorated function (sync or ``async``) must accept an ``llm`` keyword
    argument; the wrapper builds the client through the shared APIKeyManager
    and injects it. The manager, provider and key count are resolved once,
    when the decorator is applied.

    With more than one key for the provider:
      * Rate-limit errors rotate to the next key. Once every key has failed,
        the wrapper waits ``delay_on_timeout`` seconds and makes one final
        attempt starting again from the first key. If the delay is ``0``, it
        re-raises instead.
      * API errors whose message mentions "token" or "context length" first
        shrink ``max_tokens`` by 20%. This happens up to
        ``max_token_reduction_attempts`` times, and only when ``max_tokens``
        was given. After that the wrapper rotates keys, and re-raises once
        every key has failed.
      * Any other API error is re-raised immediately.
    With exactly one key, the call is made once; errors are logged and re-raised.

    Raises:
        Exception: when the provider has no API keys.
    """
    manager = APIKeyManager()
    provider = manager._get_provider_for_model()
    model_name = model_name if model_name else manager.model_name
    # `is None` so an explicit 0 / 0.0 is honored rather than replaced by the default.
    temperature = manager.temperature if temperature is None else temperature
    top_p = manager.top_p if top_p is None else top_p
    key_count = manager.get_key_count(provider)

    rate_limit_errors = (RateLimitError, ResourceExhausted, AnthropicRateLimitError)
    api_errors = (OpenAIError, AnthropicAPIError, BadRequest, InvalidArgument)
    retryable_errors = rate_limit_errors + api_errors

    def _mask(key):
        """Shorten a key for logging so the secret is never printed in full."""
        return f"{key[:4]}...{key[-4:]}" if len(key) > 8 else "****"

    def _last_used_key():
        """Return the key the current client was built with (just before the rotation index)."""
        keys = manager._api_keys[provider]
        return keys[(manager._current_indices[provider] - 1) % len(keys)]

    def _force_rotation():
        """Drop the cached client so the next get_llm() builds one with the next key.

        Rebuilding through get_llm() (rather than rotate_key()) keeps this
        call's model/temperature/top_p/max_tokens settings.
        """
        manager._llm = None

    def _build_llm(current_max_tokens):
        return manager.get_llm(
            model_name=model_name,
            temperature=temperature,
            top_p=top_p,
            max_tokens=current_max_tokens,
            streaming=streaming
        )

    def _on_rate_limit(error, tried_keys, total_keys):
        """Record a rate-limit failure and decide what to do next.

        Returns ``0`` when an untried key is available (the client is
        rotated), or the number of seconds to wait before one final attempt
        with the first key. Re-raises *error* when every key has failed and no
        delay is configured.
        """
        current_key = _last_used_key()
        print(f"Rate limit error with {provider} API key {_mask(current_key)}: {str(error)}")
        tried_keys.add(current_key)
        if len(tried_keys) < total_keys:
            _force_rotation()
            print(f"Using next available {provider} API key")
            return 0
        if delay_on_timeout <= 0:
            print(f"All {provider} API keys failed due to rate limits: {str(error)}")
            raise error
        print(f"Waiting for {delay_on_timeout} seconds before retrying with the first key...")
        with manager._lock:
            manager._current_indices[provider] = 0
        # Without this the cached client (last key) would be reused for the final attempt.
        _force_rotation()
        return delay_on_timeout

    def _on_api_error(error, tried_keys, total_keys, current_max_tokens, reductions):
        """Handle a non-rate-limit API error.

        Returns the ``(max_tokens, reductions)`` pair to use for the next
        attempt. Re-raises *error* when it is not token-related or when every
        key has already failed.
        """
        error_str = str(error)
        lowered = error_str.lower()
        if "token" not in lowered and "context length" not in lowered:
            raise error  # Re-raise other API errors
        print(f"Token limit error encountered: {error_str}")
        if (max_token_reduction_attempts > 0 and max_tokens is not None
                and reductions < max_token_reduction_attempts):
            current_max_tokens = int(current_max_tokens * 0.8)
            print(f"Retrying with reduced max_tokens: {current_max_tokens}")
            # Rebuild the client; a cached one would keep the old max_tokens.
            _force_rotation()
            return current_max_tokens, reductions + 1
        print("Max token reduction attempts reached or token reduction disabled. Proceeding with key rotation.")
        current_key = _last_used_key()
        tried_keys.add(current_key)
        if len(tried_keys) >= total_keys:
            raise error  # All keys tried, raise the token limit error
        _force_rotation()
        print(f"Using next available {provider} API key after token limit error.")
        return current_max_tokens, reductions

    def _distinct_key_count():
        # Count distinct keys: duplicated key strings would otherwise keep
        # len(tried_keys) below the list length forever.
        return len(set(manager.get_all_api_keys(provider)))

    def _no_keys():
        print(f"No API keys found for provider: {provider}")
        return Exception(f"No API key found for {provider}")

    def decorator(func):
        if asyncio.iscoroutinefunction(func):
            @wraps(func)
            async def wrapper(*args, **kwargs):
                if key_count > 1:
                    total_keys = _distinct_key_count()
                    tried_keys = set()
                    current_max_tokens = max_tokens
                    reductions = 0
                    while len(tried_keys) < total_keys:
                        try:
                            llm = _build_llm(current_max_tokens)
                            return await func(*args, **kwargs, llm=llm)
                        except rate_limit_errors as e:
                            delay = _on_rate_limit(e, tried_keys, total_keys)
                            if delay:
                                # Non-blocking wait so other tasks keep running.
                                await asyncio.sleep(delay)
                        except api_errors as e:
                            current_max_tokens, reductions = _on_api_error(
                                e, tried_keys, total_keys, current_max_tokens, reductions
                            )
                    # Attempt one final time after trying all keys (for rate limits with delay)
                    try:
                        llm = _build_llm(current_max_tokens)
                        return await func(*args, **kwargs, llm=llm)
                    except retryable_errors as e:
                        print(f"Error after retrying all {provider} API keys: {str(e)}")
                        raise
                if key_count == 1:
                    try:
                        llm = _build_llm(max_tokens)
                        return await func(*args, **kwargs, llm=llm)
                    except retryable_errors as e:
                        print(f"Error encountered for {provider} after multiple retries: {str(e)}")
                        raise
                raise _no_keys()
        else:
            @wraps(func)
            def wrapper(*args, **kwargs):
                if key_count > 1:
                    total_keys = _distinct_key_count()
                    tried_keys = set()
                    current_max_tokens = max_tokens
                    reductions = 0
                    while len(tried_keys) < total_keys:
                        try:
                            llm = _build_llm(current_max_tokens)
                            return func(*args, **kwargs, llm=llm)
                        except rate_limit_errors as e:
                            delay = _on_rate_limit(e, tried_keys, total_keys)
                            if delay:
                                time.sleep(delay)
                        except api_errors as e:
                            current_max_tokens, reductions = _on_api_error(
                                e, tried_keys, total_keys, current_max_tokens, reductions
                            )
                    # Attempt one final time after trying all keys (for rate limits with delay)
                    try:
                        llm = _build_llm(current_max_tokens)
                        return func(*args, **kwargs, llm=llm)
                    except retryable_errors as e:
                        print(f"Error after retrying all {provider} API keys: {str(e)}")
                        raise
                if key_count == 1:
                    try:
                        llm = _build_llm(max_tokens)
                        return func(*args, **kwargs, llm=llm)
                    except retryable_errors as e:
                        print(f"Error encountered for {provider} after multiple retries: {str(e)}")
                        raise
                raise _no_keys()
        return wrapper
    return decorator
if __name__ == "__main__":
    prompt = "What is the capital of France?"

    # Test key rotation
    async def test_load_balancing(prompt: str, test_count: int = 10, stream: bool = False):
        """Send *prompt* `test_count` times through a with_api_manager-managed client."""
        # The decorator supplies the required `llm` keyword argument; calling
        # `test` undecorated would raise TypeError for the missing `llm`.
        @with_api_manager(streaming=stream)
        async def test(prompt: str, test_count: int = 10, *, llm):
            print("="*50)
            for i in range(test_count):
                try:
                    print(f"\nTest {i+1} of {test_count}")
                    if stream:
                        async for chunk in llm.astream(prompt):
                            print(chunk.content, end="", flush=True)
                        print("\n" + "-"*50 if i != test_count - 1 else "\n" + "="*50)
                    else:
                        response = await llm.ainvoke(prompt)
                        print(f"Response: {response.content.strip()}")
                        print("-"*50 if i != test_count - 1 else "="*50)
                except (RateLimitError, ResourceExhausted, AnthropicRateLimitError) as e:
                    print(f"Error: {str(e)}")
                    raise  # let the decorator rotate keys

        await test(prompt, test_count=test_count)

    # Test without load balancing
    def test_without_load_balancing(model_name: str, prompt: str, test_count: int = 10):
        """Send *prompt* `test_count` times via APIKeyManager.get_llm, stopping on the first error."""
        manager = APIKeyManager()
        print(f"Using model: {model_name}")
        print("="*50)
        for i in range(test_count):
            try:
                print(f"Test {i+1} of {test_count}")
                llm = manager.get_llm(model_name=model_name)
                response = llm.invoke(prompt)
                print(f"Response: {response.content.strip()}")
                print("-"*50 if i != test_count - 1 else "="*50)
            except Exception as e:
                raise Exception(f"Error with {model_name}: {str(e)}") from e

    test_without_load_balancing(model_name="gemini-2.5-flash-lite-preview-06-17", prompt=prompt, test_count=50)
    # asyncio.run(test_load_balancing(prompt=prompt, test_count=100, stream=True))