Spaces:
Paused
Paused
| import asyncio | |
| import fnmatch | |
| import json | |
| import re | |
| import codecs | |
| import time | |
| import os | |
| import random | |
| import httpx | |
| import litellm | |
| from litellm.exceptions import APIConnectionError | |
| from litellm.litellm_core_utils.token_counter import token_counter | |
| import logging | |
| from pathlib import Path | |
| from typing import List, Dict, Any, AsyncGenerator, Optional, Union | |
| lib_logger = logging.getLogger("rotator_library") | |
| # Propagation is disabled by default. RotatingClient.__init__ enables it | |
| # when configure_logging=True so that the host application (main.py) can | |
| # control log levels and handlers centrally. | |
| lib_logger.propagate = False | |
| from .usage_manager import UsageManager | |
| from .failure_logger import log_failure, configure_failure_logger | |
| from .error_handler import ( | |
| PreRequestCallbackError, | |
| CredentialNeedsReauthError, | |
| classify_error, | |
| AllProviders, | |
| NoAvailableKeysError, | |
| should_rotate_on_error, | |
| should_retry_same_key, | |
| RequestErrorAccumulator, | |
| mask_credential, | |
| ) | |
| from .providers import PROVIDER_PLUGINS | |
| from .providers.openai_compatible_provider import OpenAICompatibleProvider | |
| from .request_sanitizer import sanitize_request_payload | |
| from .cooldown_manager import CooldownManager | |
| from .credential_manager import CredentialManager | |
| from .background_refresher import BackgroundRefresher | |
| from .model_definitions import ModelDefinitions | |
| from .utils.paths import get_default_root, get_logs_dir, get_oauth_dir, get_data_file | |
class StreamedAPIError(Exception):
    """Raised when a provider reports an API error inside a streamed response.

    The optional ``data`` payload carries whatever describes the failure
    (for example a parsed error object or the underlying exception).
    """

    def __init__(self, message, data=None):
        # Keep the payload on the instance so handlers can inspect it.
        self.data = data
        super().__init__(message)
| class RotatingClient: | |
| """ | |
| A client that intelligently rotates and retries API keys using LiteLLM, | |
| with support for both streaming and non-streaming responses. | |
| """ | |
def __init__(
    self,
    api_keys: Optional[Dict[str, List[str]]] = None,
    oauth_credentials: Optional[Dict[str, List[str]]] = None,
    max_retries: int = 2,
    usage_file_path: Optional[Union[str, Path]] = None,
    configure_logging: bool = True,
    global_timeout: int = 30,
    abort_on_callback_error: bool = True,
    litellm_provider_params: Optional[Dict[str, Any]] = None,
    ignore_models: Optional[Dict[str, List[str]]] = None,
    whitelist_models: Optional[Dict[str, List[str]]] = None,
    enable_request_logging: bool = False,
    max_concurrent_requests_per_key: Optional[Dict[str, int]] = None,
    rotation_tolerance: float = 3.0,
    data_dir: Optional[Union[str, Path]] = None,
):
    """
    Initialize the RotatingClient with intelligent credential rotation.

    Args:
        api_keys: Dictionary mapping provider names to lists of API keys
        oauth_credentials: Dictionary mapping provider names to OAuth credential paths
        max_retries: Maximum number of retry attempts per credential
        usage_file_path: Path to store usage statistics. If None, uses data_dir/key_usage.json
        configure_logging: Whether to configure library logging
        global_timeout: Global timeout for requests in seconds
        abort_on_callback_error: Whether to abort on pre-request callback errors
        litellm_provider_params: Provider-specific parameters for LiteLLM
        ignore_models: Models to ignore/blacklist per provider
        whitelist_models: Models to explicitly whitelist per provider
        enable_request_logging: Whether to enable detailed request logging
        max_concurrent_requests_per_key: Max concurrent requests per key by provider.
            Values below 1 are clamped to 1 on an internal copy; the caller's
            dictionary is never modified.
        rotation_tolerance: Tolerance for weighted random credential rotation.
            - 0.0: Deterministic, least-used credential always selected
            - 2.0 - 4.0 (default, recommended): Balanced randomness, can pick credentials within 2 uses of max
            - 5.0+: High randomness, more unpredictable selection patterns
        data_dir: Root directory for all data files (logs, cache, oauth_creds, key_usage.json).
            If None, the root returned by get_default_root() is used.
    """
    # Resolve data_dir early - this becomes the root for all file operations.
    if data_dir is not None:
        self.data_dir = Path(data_dir).resolve()
    else:
        self.data_dir = get_default_root()

    # Configure failure logger to use the correct logs directory.
    configure_failure_logger(get_logs_dir(self.data_dir))

    # Keep LiteLLM quiet and let it drop parameters a provider does not support.
    os.environ["LITELLM_LOG"] = "ERROR"
    litellm.set_verbose = False
    litellm.drop_params = True

    if configure_logging:
        # Let the parent application's logging configuration handle our records.
        lib_logger.propagate = True
        # Remove any default handlers to prevent duplicate logging.
        if lib_logger.hasHandlers():
            lib_logger.handlers.clear()
            lib_logger.addHandler(logging.NullHandler())
    else:
        lib_logger.propagate = False

    # Drop providers whose credential list is empty so every entry is usable.
    api_keys = {
        provider: keys for provider, keys in (api_keys or {}).items() if keys
    }
    oauth_credentials = {
        provider: paths
        for provider, paths in (oauth_credentials or {}).items()
        if paths
    }
    if not api_keys and not oauth_credentials:
        lib_logger.warning(
            "No provider credentials configured. The client will be unable to make any API requests."
        )
    self.api_keys = api_keys

    # Use the provided oauth_credentials directly when available; only run
    # discovery when none were passed in. Note that credential_manager is
    # only assigned on the discovery path.
    if oauth_credentials:
        self.oauth_credentials = oauth_credentials
    else:
        self.credential_manager = CredentialManager(
            os.environ, oauth_dir=get_oauth_dir(self.data_dir)
        )
        self.oauth_credentials = self.credential_manager.discover_and_prepare()
    self.background_refresher = BackgroundRefresher(self)
    self.oauth_providers = set(self.oauth_credentials.keys())

    # Merge API keys and OAuth credential paths into one per-provider list.
    all_credentials = {}
    for provider, keys in api_keys.items():
        all_credentials.setdefault(provider, []).extend(keys)
    for provider, paths in self.oauth_credentials.items():
        all_credentials.setdefault(provider, []).extend(paths)
    self.all_credentials = all_credentials

    self.max_retries = max_retries
    self.global_timeout = global_timeout
    self.abort_on_callback_error = abort_on_callback_error

    # Provider plugins must be available before rotation modes are detected.
    self._provider_plugins = PROVIDER_PLUGINS
    self._provider_instances = {}

    provider_rotation_modes = self._init_rotation_modes()
    (
        priority_multipliers,
        priority_multipliers_by_mode,
        sequential_fallback_multipliers,
    ) = self._init_priority_multipliers()

    # Resolve usage file path - use provided path or default to data_dir.
    if usage_file_path is not None:
        resolved_usage_path = Path(usage_file_path)
    else:
        resolved_usage_path = self.data_dir / "key_usage.json"

    self.usage_manager = UsageManager(
        file_path=resolved_usage_path,
        rotation_tolerance=rotation_tolerance,
        provider_rotation_modes=provider_rotation_modes,
        provider_plugins=PROVIDER_PLUGINS,
        priority_multipliers=priority_multipliers,
        priority_multipliers_by_mode=priority_multipliers_by_mode,
        sequential_fallback_multipliers=sequential_fallback_multipliers,
    )
    self._model_list_cache = {}
    self.http_client = httpx.AsyncClient()
    self.all_providers = AllProviders()
    self.cooldown_manager = CooldownManager()
    self.litellm_provider_params = litellm_provider_params or {}
    self.ignore_models = ignore_models or {}
    self.whitelist_models = whitelist_models or {}
    self.enable_request_logging = enable_request_logging
    self.model_definitions = ModelDefinitions()

    # Store a private copy so clamping below never mutates the caller's dict.
    self.max_concurrent_requests_per_key = dict(
        max_concurrent_requests_per_key or {}
    )
    for provider, max_val in self.max_concurrent_requests_per_key.items():
        if max_val < 1:
            lib_logger.warning(
                f"Invalid max_concurrent for '{provider}': {max_val}. Setting to 1."
            )
            self.max_concurrent_requests_per_key[provider] = 1


def _init_rotation_modes(self) -> Dict[str, str]:
    """Return the rotation mode ("balanced", "sequential", ...) for each provider with credentials."""
    provider_rotation_modes: Dict[str, str] = {}
    for provider in self.all_credentials:
        provider_class = self._provider_plugins.get(provider)
        if provider_class and hasattr(provider_class, "get_rotation_mode"):
            # The plugin class decides (per the original note: env var + class default).
            mode = provider_class.get_rotation_mode(provider)
        else:
            # Fallback: check environment variable directly.
            mode = os.getenv(f"ROTATION_MODE_{provider.upper()}", "balanced")
        provider_rotation_modes[provider] = mode
        if mode != "balanced":
            lib_logger.info(f"Provider '{provider}' using rotation mode: {mode}")
    return provider_rotation_modes


def _init_priority_multipliers(self):
    """Build the concurrency multiplier maps from plugin defaults and environment overrides.

    Returns a 3-tuple:
        (priority_multipliers, priority_multipliers_by_mode,
         sequential_fallback_multipliers)

    Environment override formats:
        CONCURRENCY_MULTIPLIER_<PROVIDER>_PRIORITY_<N>=<multiplier>
        CONCURRENCY_MULTIPLIER_<PROVIDER>_PRIORITY_<N>_<MODE>=<multiplier>
    """
    priority_multipliers: Dict[str, Dict[int, int]] = {}
    priority_multipliers_by_mode: Dict[str, Dict[str, Dict[int, int]]] = {}
    sequential_fallback_multipliers: Dict[str, int] = {}

    for provider in self.all_credentials:
        provider_class = self._provider_plugins.get(provider)

        # Start with provider class defaults.
        if provider_class:
            if hasattr(provider_class, "default_priority_multipliers"):
                default_multipliers = provider_class.default_priority_multipliers
                if default_multipliers:
                    priority_multipliers[provider] = dict(default_multipliers)
            if hasattr(provider_class, "default_sequential_fallback_multiplier"):
                fallback = provider_class.default_sequential_fallback_multiplier
                if fallback != 1:  # Only store if different from global default
                    sequential_fallback_multipliers[provider] = fallback

        # Override with environment variables. The prefix only depends on the
        # provider, so it is computed once per provider rather than per variable.
        prefix = f"CONCURRENCY_MULTIPLIER_{provider.upper()}_PRIORITY_"
        for key, value in os.environ.items():
            if not key.startswith(prefix):
                continue
            remainder = key[len(prefix) :]
            try:
                multiplier = int(value)
                if multiplier < 1:
                    lib_logger.warning(f"Invalid {key}: {value}. Must be >= 1.")
                    continue

                if "_" in remainder:
                    # Mode-specific form, e.g. _PRIORITY_1_SEQUENTIAL.
                    priority_text, mode_text = remainder.rsplit("_", 1)
                    priority = int(priority_text)
                    mode = mode_text.lower()
                    if mode in ("sequential", "balanced"):
                        priority_multipliers_by_mode.setdefault(
                            provider, {}
                        ).setdefault(mode, {})[priority] = multiplier
                        lib_logger.info(
                            f"Provider '{provider}' priority {priority} ({mode} mode) multiplier: {multiplier}x"
                        )
                    else:
                        lib_logger.warning(f"Unknown mode in {key}: {mode}")
                else:
                    # Universal priority multiplier.
                    priority = int(remainder)
                    priority_multipliers.setdefault(provider, {})[
                        priority
                    ] = multiplier
                    lib_logger.info(
                        f"Provider '{provider}' priority {priority} multiplier: {multiplier}x"
                    )
            except ValueError:
                lib_logger.warning(
                    f"Invalid {key}: {value}. Could not parse priority or multiplier."
                )

    # Log configured multipliers.
    for provider, multipliers in priority_multipliers.items():
        if multipliers:
            lib_logger.info(
                f"Provider '{provider}' priority multipliers: {multipliers}"
            )
    for provider, fallback in sequential_fallback_multipliers.items():
        lib_logger.info(
            f"Provider '{provider}' sequential fallback multiplier: {fallback}x"
        )

    return (
        priority_multipliers,
        priority_multipliers_by_mode,
        sequential_fallback_multipliers,
    )
| def _is_model_ignored(self, provider: str, model_id: str) -> bool: | |
| """ | |
| Checks if a model should be ignored based on the ignore list. | |
| Supports full glob/fnmatch patterns for both full model IDs and model names. | |
| Pattern examples: | |
| - "gpt-4" - exact match | |
| - "gpt-4*" - prefix wildcard (matches gpt-4, gpt-4-turbo, etc.) | |
| - "*-preview" - suffix wildcard (matches gpt-4-preview, o1-preview, etc.) | |
| - "*-preview*" - contains wildcard (matches anything with -preview) | |
| - "*" - match all | |
| """ | |
| model_provider = model_id.split("/")[0] | |
| if model_provider not in self.ignore_models: | |
| return False | |
| ignore_list = self.ignore_models[model_provider] | |
| if ignore_list == ["*"]: | |
| return True | |
| try: | |
| # This is the model name as the provider sees it (e.g., "gpt-4" or "google/gemma-7b") | |
| provider_model_name = model_id.split("/", 1)[1] | |
| except IndexError: | |
| provider_model_name = model_id | |
| for ignored_pattern in ignore_list: | |
| # Use fnmatch for full glob pattern support | |
| if fnmatch.fnmatch(provider_model_name, ignored_pattern) or fnmatch.fnmatch( | |
| model_id, ignored_pattern | |
| ): | |
| return True | |
| return False | |
| def _is_model_whitelisted(self, provider: str, model_id: str) -> bool: | |
| """ | |
| Checks if a model is explicitly whitelisted. | |
| Supports full glob/fnmatch patterns for both full model IDs and model names. | |
| Pattern examples: | |
| - "gpt-4" - exact match | |
| - "gpt-4*" - prefix wildcard (matches gpt-4, gpt-4-turbo, etc.) | |
| - "*-preview" - suffix wildcard (matches gpt-4-preview, o1-preview, etc.) | |
| - "*-preview*" - contains wildcard (matches anything with -preview) | |
| - "*" - match all | |
| """ | |
| model_provider = model_id.split("/")[0] | |
| if model_provider not in self.whitelist_models: | |
| return False | |
| whitelist = self.whitelist_models[model_provider] | |
| try: | |
| # This is the model name as the provider sees it (e.g., "gpt-4" or "google/gemma-7b") | |
| provider_model_name = model_id.split("/", 1)[1] | |
| except IndexError: | |
| provider_model_name = model_id | |
| for whitelisted_pattern in whitelist: | |
| # Use fnmatch for full glob pattern support | |
| if fnmatch.fnmatch( | |
| provider_model_name, whitelisted_pattern | |
| ) or fnmatch.fnmatch(model_id, whitelisted_pattern): | |
| return True | |
| return False | |
| def _sanitize_litellm_log(self, log_data: dict) -> dict: | |
| """ | |
| Recursively removes large data fields and sensitive information from litellm log | |
| dictionaries to keep debug logs clean and secure. | |
| """ | |
| if not isinstance(log_data, dict): | |
| return log_data | |
| # Keys to remove at any level of the dictionary | |
| keys_to_pop = [ | |
| "messages", | |
| "input", | |
| "response", | |
| "data", | |
| "api_key", | |
| "api_base", | |
| "original_response", | |
| "additional_args", | |
| ] | |
| # Keys that might contain nested dictionaries to clean | |
| nested_keys = ["kwargs", "litellm_params", "model_info", "proxy_server_request"] | |
| # Create a deep copy to avoid modifying the original log object in memory | |
| clean_data = json.loads(json.dumps(log_data, default=str)) | |
| def clean_recursively(data_dict): | |
| if not isinstance(data_dict, dict): | |
| return | |
| # Remove sensitive/large keys | |
| for key in keys_to_pop: | |
| data_dict.pop(key, None) | |
| # Recursively clean nested dictionaries | |
| for key in nested_keys: | |
| if key in data_dict and isinstance(data_dict[key], dict): | |
| clean_recursively(data_dict[key]) | |
| # Also iterate through all values to find any other nested dicts | |
| for key, value in list(data_dict.items()): | |
| if isinstance(value, dict): | |
| clean_recursively(value) | |
| clean_recursively(clean_data) | |
| return clean_data | |
def _litellm_logger_callback(self, log_data: dict):
    """
    Redirect litellm's log callbacks to the library logger at DEBUG level.

    - "pre_api_call" / "post_api_call" events are dropped entirely.
    - Events without an exception are logged as a sanitized dictionary.
    - Events with an exception are condensed to model, error class and a
      whitespace-normalized error message.

    Args:
        log_data: Log dictionary passed by litellm.
    """
    # Filter out verbose pre_api_call and post_api_call logs.
    if log_data.get("log_event_type") in ("pre_api_call", "post_api_call"):
        return

    exception = log_data.get("exception")
    if not exception:
        # No failure: a sanitized dump at DEBUG level is enough.
        sanitized_log = self._sanitize_litellm_log(log_data)
        lib_logger.debug(f"LiteLLM Log: {sanitized_log}")
        return

    # For failures, extract key info to make debug logs more readable.
    model = log_data.get("model", "N/A")
    # Either level may be present but None; treat that like "missing" so the
    # logging callback itself never raises.
    standard_logging = log_data.get("standard_logging_object") or {}
    error_info = standard_logging.get("error_information") or {}
    error_class = error_info.get("error_class", "UnknownError")
    error_message = error_info.get("error_message")
    if error_message is None:
        error_message = str(exception)
    # Collapse all whitespace runs so the message fits on one log line;
    # str() guards against non-string messages.
    error_message = " ".join(str(error_message).split())

    lib_logger.debug(
        f"LiteLLM Callback Handled Error: Model={model} | "
        f"Type={error_class} | Message='{error_message}'"
    )
| async def __aenter__(self): | |
| return self | |
| async def __aexit__(self, exc_type, exc_val, exc_tb): | |
| await self.close() | |
async def close(self):
    """Close the HTTP client to prevent resource leaks."""
    # The attribute may be absent (e.g. construction failed early) or falsy;
    # in both cases there is nothing to close.
    client = getattr(self, "http_client", None)
    if not client:
        return
    await client.aclose()
| def _convert_model_params(self, **kwargs) -> Dict[str, Any]: | |
| """ | |
| Converts model parameters for specific providers. | |
| For example, the 'chutes' provider requires the model name to be prepended | |
| with 'openai/' and a specific 'api_base'. | |
| """ | |
| model = kwargs.get("model") | |
| if not model: | |
| return kwargs | |
| provider = model.split("/")[0] | |
| if provider == "chutes": | |
| kwargs["model"] = f"openai/{model.split('/', 1)[1]}" | |
| kwargs["api_base"] = "https://llm.chutes.ai/v1" | |
| return kwargs | |
| def _convert_model_params_for_litellm(self, **kwargs) -> Dict[str, Any]: | |
| """ | |
| Converts model parameters specifically for LiteLLM calls. | |
| This is called right before calling LiteLLM to handle custom providers. | |
| """ | |
| model = kwargs.get("model") | |
| if not model: | |
| return kwargs | |
| provider = model.split("/")[0] | |
| # Handle custom OpenAI-compatible providers | |
| # Check if this is a custom provider by looking for API_BASE environment variable | |
| import os | |
| api_base_env = f"{provider.upper()}_API_BASE" | |
| if os.getenv(api_base_env): | |
| # For custom providers, tell LiteLLM to use openai provider with custom model name | |
| # This preserves original model name in logs but converts for LiteLLM | |
| kwargs = kwargs.copy() # Don't modify original | |
| kwargs["model"] = f"openai/{model.split('/', 1)[1]}" | |
| kwargs["api_base"] = os.getenv(api_base_env).rstrip("/") | |
| kwargs["custom_llm_provider"] = "openai" | |
| return kwargs | |
| def _apply_default_safety_settings( | |
| self, litellm_kwargs: Dict[str, Any], provider: str | |
| ): | |
| """ | |
| Ensure default Gemini safety settings are present when calling the Gemini provider. | |
| This will not override any explicit settings provided by the request. It accepts | |
| either OpenAI-compatible generic `safety_settings` (dict) or direct Gemini-style | |
| `safetySettings` (list of dicts). Missing categories will be added with safe defaults. | |
| """ | |
| if provider != "gemini": | |
| return | |
| # Generic defaults (openai-compatible style) | |
| default_generic = { | |
| "harassment": "OFF", | |
| "hate_speech": "OFF", | |
| "sexually_explicit": "OFF", | |
| "dangerous_content": "OFF", | |
| "civic_integrity": "BLOCK_NONE", | |
| } | |
| # Gemini defaults (direct Gemini format) | |
| default_gemini = [ | |
| {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, | |
| {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, | |
| {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, | |
| {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, | |
| {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}, | |
| ] | |
| # If generic form is present, ensure missing generic keys are filled in | |
| if "safety_settings" in litellm_kwargs and isinstance( | |
| litellm_kwargs["safety_settings"], dict | |
| ): | |
| for k, v in default_generic.items(): | |
| if k not in litellm_kwargs["safety_settings"]: | |
| litellm_kwargs["safety_settings"][k] = v | |
| return | |
| # If Gemini form is present, ensure missing gemini categories are appended | |
| if "safetySettings" in litellm_kwargs and isinstance( | |
| litellm_kwargs["safetySettings"], list | |
| ): | |
| present = { | |
| item.get("category") | |
| for item in litellm_kwargs["safetySettings"] | |
| if isinstance(item, dict) | |
| } | |
| for d in default_gemini: | |
| if d["category"] not in present: | |
| litellm_kwargs["safetySettings"].append(d) | |
| return | |
| # Neither present: set generic defaults so provider conversion will translate them | |
| if ( | |
| "safety_settings" not in litellm_kwargs | |
| and "safetySettings" not in litellm_kwargs | |
| ): | |
| litellm_kwargs["safety_settings"] = default_generic.copy() | |
def get_oauth_credentials(self) -> Dict[str, List[str]]:
    """Return the provider -> OAuth credential list mapping held by this client (not a copy)."""
    return self.oauth_credentials
| def _is_custom_openai_compatible_provider(self, provider_name: str) -> bool: | |
| """Checks if a provider is a custom OpenAI-compatible provider.""" | |
| import os | |
| # Check if the provider has an API_BASE environment variable | |
| api_base_env = f"{provider_name.upper()}_API_BASE" | |
| return os.getenv(api_base_env) is not None | |
| def _get_provider_instance(self, provider_name: str): | |
| """ | |
| Lazily initializes and returns a provider instance. | |
| Only initializes providers that have configured credentials. | |
| Args: | |
| provider_name: The name of the provider to get an instance for. | |
| For OAuth providers, this may include "_oauth" suffix | |
| (e.g., "antigravity_oauth"), but credentials are stored | |
| under the base name (e.g., "antigravity"). | |
| Returns: | |
| Provider instance if credentials exist, None otherwise. | |
| """ | |
| # For OAuth providers, credentials are stored under base name (without _oauth suffix) | |
| # e.g., "antigravity_oauth" plugin → credentials under "antigravity" | |
| credential_key = provider_name | |
| if provider_name.endswith("_oauth"): | |
| base_name = provider_name[:-6] # Remove "_oauth" | |
| if base_name in self.oauth_providers: | |
| credential_key = base_name | |
| # Only initialize providers for which we have credentials | |
| if credential_key not in self.all_credentials: | |
| lib_logger.debug( | |
| f"Skipping provider '{provider_name}' initialization: no credentials configured" | |
| ) | |
| return None | |
| if provider_name not in self._provider_instances: | |
| if provider_name in self._provider_plugins: | |
| self._provider_instances[provider_name] = self._provider_plugins[ | |
| provider_name | |
| ]() | |
| elif self._is_custom_openai_compatible_provider(provider_name): | |
| # Create a generic OpenAI-compatible provider for custom providers | |
| try: | |
| self._provider_instances[provider_name] = OpenAICompatibleProvider( | |
| provider_name | |
| ) | |
| except ValueError: | |
| # If the provider doesn't have the required environment variables, treat it as a standard provider | |
| return None | |
| else: | |
| return None | |
| return self._provider_instances[provider_name] | |
| def _resolve_model_id(self, model: str, provider: str) -> str: | |
| """ | |
| Resolves the actual model ID to send to the provider. | |
| For custom models with name/ID mappings, returns the ID. | |
| Otherwise, returns the model name unchanged. | |
| Args: | |
| model: Full model string with provider (e.g., "iflow/DS-v3.2") | |
| provider: Provider name (e.g., "iflow") | |
| Returns: | |
| Full model string with ID (e.g., "iflow/deepseek-v3.2") | |
| """ | |
| # Extract model name from "provider/model_name" format | |
| model_name = model.split("/")[-1] if "/" in model else model | |
| # Try to get provider instance to check for model definitions | |
| provider_plugin = self._get_provider_instance(provider) | |
| # Check if provider has model definitions | |
| if provider_plugin and hasattr(provider_plugin, "model_definitions"): | |
| model_id = provider_plugin.model_definitions.get_model_id( | |
| provider, model_name | |
| ) | |
| if model_id and model_id != model_name: | |
| # Return with provider prefix | |
| return f"{provider}/{model_id}" | |
| # Fallback: use client's own model definitions | |
| model_id = self.model_definitions.get_model_id(provider, model_name) | |
| if model_id and model_id != model_name: | |
| return f"{provider}/{model_id}" | |
| # No conversion needed, return original | |
| return model | |
| async def _safe_streaming_wrapper( | |
| self, | |
| stream: Any, | |
| key: str, | |
| model: str, | |
| request: Optional[Any] = None, | |
| provider_plugin: Optional[Any] = None, | |
| ) -> AsyncGenerator[Any, None]: | |
| """ | |
| A hybrid wrapper for streaming that buffers fragmented JSON, handles client disconnections gracefully, | |
| and distinguishes between content and streamed errors. | |
| FINISH_REASON HANDLING: | |
| Providers just translate chunks - this wrapper handles ALL finish_reason logic: | |
| 1. Strip finish_reason from intermediate chunks (litellm defaults to "stop") | |
| 2. Track accumulated_finish_reason with priority: tool_calls > length/content_filter > stop | |
| 3. Only emit finish_reason on final chunk (detected by usage.completion_tokens > 0) | |
| """ | |
| last_usage = None | |
| stream_completed = False | |
| stream_iterator = stream.__aiter__() | |
| json_buffer = "" | |
| accumulated_finish_reason = None # Track strongest finish_reason across chunks | |
| has_tool_calls = False # Track if ANY tool calls were seen in stream | |
| try: | |
| while True: | |
| if request and await request.is_disconnected(): | |
| lib_logger.info( | |
| f"Client disconnected. Aborting stream for credential {mask_credential(key)}." | |
| ) | |
| break | |
| try: | |
| chunk = await stream_iterator.__anext__() | |
| if json_buffer: | |
| lib_logger.warning( | |
| f"Discarding incomplete JSON buffer from previous chunk: {json_buffer}" | |
| ) | |
| json_buffer = "" | |
| # Convert chunk to dict, handling both litellm.ModelResponse and raw dicts | |
| if hasattr(chunk, "dict"): | |
| chunk_dict = chunk.dict() | |
| elif hasattr(chunk, "model_dump"): | |
| chunk_dict = chunk.model_dump() | |
| else: | |
| chunk_dict = chunk | |
| # === FINISH_REASON LOGIC === | |
| # Providers send raw chunks without finish_reason logic. | |
| # This wrapper determines finish_reason based on accumulated state. | |
| if "choices" in chunk_dict and chunk_dict["choices"]: | |
| choice = chunk_dict["choices"][0] | |
| delta = choice.get("delta", {}) | |
| usage = chunk_dict.get("usage", {}) | |
| # Track tool_calls across ALL chunks - if we ever see one, finish_reason must be tool_calls | |
| if delta.get("tool_calls"): | |
| has_tool_calls = True | |
| accumulated_finish_reason = "tool_calls" | |
| # Detect final chunk: has usage with completion_tokens > 0 | |
| has_completion_tokens = ( | |
| usage | |
| and isinstance(usage, dict) | |
| and usage.get("completion_tokens", 0) > 0 | |
| ) | |
| if has_completion_tokens: | |
| # FINAL CHUNK: Determine correct finish_reason | |
| if has_tool_calls: | |
| # Tool calls always win | |
| choice["finish_reason"] = "tool_calls" | |
| elif accumulated_finish_reason: | |
| # Use accumulated reason (length, content_filter, etc.) | |
| choice["finish_reason"] = accumulated_finish_reason | |
| else: | |
| # Default to stop | |
| choice["finish_reason"] = "stop" | |
| else: | |
| # INTERMEDIATE CHUNK: Never emit finish_reason | |
| # (litellm.ModelResponse defaults to "stop" which is wrong) | |
| choice["finish_reason"] = None | |
| yield f"data: {json.dumps(chunk_dict)}\n\n" | |
| if hasattr(chunk, "usage") and chunk.usage: | |
| last_usage = chunk.usage | |
| except StopAsyncIteration: | |
| stream_completed = True | |
| if json_buffer: | |
| lib_logger.info( | |
| f"Stream ended with incomplete data in buffer: {json_buffer}" | |
| ) | |
| if last_usage: | |
| # Create a dummy ModelResponse for recording (only usage matters) | |
| dummy_response = litellm.ModelResponse(usage=last_usage) | |
| await self.usage_manager.record_success( | |
| key, model, dummy_response | |
| ) | |
| else: | |
| # If no usage seen (rare), record success without tokens/cost | |
| await self.usage_manager.record_success(key, model) | |
| break | |
| except CredentialNeedsReauthError as e: | |
| # This credential needs re-authentication but re-auth is already queued. | |
| # Wrap it so the outer retry loop can rotate to the next credential. | |
| # No scary traceback needed - this is an expected recovery scenario. | |
| raise StreamedAPIError("Credential needs re-authentication", data=e) | |
| except ( | |
| litellm.RateLimitError, | |
| litellm.ServiceUnavailableError, | |
| litellm.InternalServerError, | |
| APIConnectionError, | |
| httpx.HTTPStatusError, | |
| ) as e: | |
| # This is a critical, typed error from litellm or httpx that signals a key failure. | |
| # We do not try to parse it here. We wrap it and raise it immediately | |
| # for the outer retry loop to handle. | |
| lib_logger.warning( | |
| f"Caught a critical API error mid-stream: {type(e).__name__}. Signaling for credential rotation." | |
| ) | |
| raise StreamedAPIError("Provider error received in stream", data=e) | |
| except Exception as e: | |
| try: | |
| raw_chunk = "" | |
| # Google streams errors inside a bytes representation (b'{...}'). | |
| # We use regex to extract the content, which is more reliable than splitting. | |
| match = re.search(r"b'(\{.*\})'", str(e), re.DOTALL) | |
| if match: | |
| # The extracted string is unicode-escaped (e.g., '\\n'). We must decode it. | |
| raw_chunk = codecs.decode(match.group(1), "unicode_escape") | |
| else: | |
| # Fallback for other potential error formats that use "Received chunk:". | |
| chunk_from_split = ( | |
| str(e).split("Received chunk:")[-1].strip() | |
| ) | |
| if chunk_from_split != str( | |
| e | |
| ): # Ensure the split actually did something | |
| raw_chunk = chunk_from_split | |
| if not raw_chunk: | |
| # If we could not extract a valid chunk, we cannot proceed with reassembly. | |
| # This indicates a different, unexpected error type. Re-raise it. | |
| raise e | |
| # Append the clean chunk to the buffer and try to parse. | |
| json_buffer += raw_chunk | |
| parsed_data = json.loads(json_buffer) | |
| # If parsing succeeds, we have the complete object. | |
| lib_logger.info( | |
| f"Successfully reassembled JSON from stream: {json_buffer}" | |
| ) | |
| # Wrap the complete error object and raise it. The outer function will decide how to handle it. | |
| raise StreamedAPIError( | |
| "Provider error received in stream", data=parsed_data | |
| ) | |
| except json.JSONDecodeError: | |
| # This is the expected outcome if the JSON in the buffer is not yet complete. | |
| lib_logger.info( | |
| f"Buffer still incomplete. Waiting for more chunks: {json_buffer}" | |
| ) | |
| continue # Continue to the next loop to get the next chunk. | |
| except StreamedAPIError: | |
| # Re-raise to be caught by the outer retry handler. | |
| raise | |
| except Exception as buffer_exc: | |
| # If the error was not a JSONDecodeError, it's an unexpected internal error. | |
| lib_logger.error( | |
| f"Error during stream buffering logic: {buffer_exc}. Discarding buffer." | |
| ) | |
| json_buffer = ( | |
| "" # Clear the corrupted buffer to prevent further issues. | |
| ) | |
| raise buffer_exc | |
| except StreamedAPIError: | |
| # This is caught by the acompletion retry logic. | |
| # We re-raise it to ensure it's not caught by the generic 'except Exception'. | |
| raise | |
| except Exception as e: | |
| # Catch any other unexpected errors during streaming. | |
| lib_logger.error(f"Caught unexpected exception of type: {type(e).__name__}") | |
| lib_logger.error( | |
| f"An unexpected error occurred during the stream for credential {mask_credential(key)}: {e}" | |
| ) | |
| # We still need to raise it so the client knows something went wrong. | |
| raise | |
| finally: | |
| # This block now runs regardless of how the stream terminates (completion, client disconnect, etc.). | |
| # The primary goal is to ensure usage is always logged internally. | |
| await self.usage_manager.release_key(key, model) | |
| lib_logger.info( | |
| f"STREAM FINISHED and lock released for credential {mask_credential(key)}." | |
| ) | |
| # Only send [DONE] if the stream completed naturally and the client is still there. | |
| # This prevents sending [DONE] to a disconnected client or after an error. | |
| if stream_completed and ( | |
| not request or not await request.is_disconnected() | |
| ): | |
| yield "data: [DONE]\n\n" | |
async def _execute_with_retry(
    self,
    api_call: callable,
    request: Optional[Any],
    pre_request_callback: Optional[callable] = None,
    **kwargs,
) -> Any:
    """Run a non-streaming API call, rotating credentials and retrying on failure.

    The provider is taken from the prefix of ``kwargs["model"]`` (the text
    before the first ``/``). Credentials for that provider are shuffled,
    filtered, and then tried one at a time until a call succeeds, every
    credential has been tried, or the global deadline
    (``self.global_timeout`` seconds from entry) passes.

    For each credential up to ``self.max_retries`` attempts are made:

    * connection / internal-server / service-unavailable errors are retried
      on the same credential with exponential backoff (or the provider's
      ``retry_after``) as long as the wait fits in the remaining budget;
    * rate-limit errors start a provider-wide cooldown and rotate to the
      next credential;
    * errors for which ``should_rotate_on_error`` returns false are re-raised
      immediately.

    Args:
        api_call: Awaitable callable used for providers without custom
            logic; called with the prepared keyword arguments plus
            ``logger_fn``.
        request: Optional incoming request object. When present its
            ``headers`` are logged on failure and ``is_disconnected()`` is
            consulted before rotating after an unexpected error.
        pre_request_callback: Optional awaitable invoked as
            ``pre_request_callback(request, litellm_kwargs)`` before every
            attempt.
        **kwargs: Request parameters; ``model`` is required.

    Returns:
        The provider response on success. If all credentials fail (or the
        deadline passes) and errors were recorded, the structured error
        response built by the error accumulator. ``None`` if the loop ended
        without any recorded error.

    Raises:
        ValueError: If ``model`` is missing or the provider has no
            configured credentials.
        Exception: The last provider exception when it is classified as
            non-rotatable, or when the client disconnected mid-request.
            Exceptions raised outside the per-attempt handlers (for example
            while acquiring a key) propagate unchanged.
    """
    model = kwargs.get("model")
    if not model:
        raise ValueError("'model' is a required parameter.")

    provider = model.split("/")[0]
    if provider not in self.all_credentials:
        raise ValueError(
            f"No API keys or OAuth credentials configured for provider: {provider}"
        )

    # Establish a global deadline for the entire request lifecycle.
    deadline = time.time() + self.global_timeout

    # Work on a shuffled copy so that selection is randomized when several
    # keys have the same usage stats.
    credentials_for_provider = list(self.all_credentials[provider])
    random.shuffle(credentials_for_provider)

    # Filter out credentials that are unavailable (queued for re-auth).
    provider_plugin = self._get_provider_instance(provider)
    if provider_plugin and hasattr(provider_plugin, "is_credential_available"):
        available_creds = [
            cred
            for cred in credentials_for_provider
            if provider_plugin.is_credential_available(cred)
        ]
        # If all credentials are unavailable, keep the original list
        # (better to try unavailable creds than fail immediately).
        if available_creds:
            credentials_for_provider = available_creds

    tried_creds = set()
    last_exception = None
    kwargs = self._convert_model_params(**kwargs)

    # Resolve the model ID early, before any credential operations, so the
    # same ID is used for acquisition, release, and tracking.
    resolved_model = self._resolve_model_id(model, provider)
    if resolved_model != model:
        lib_logger.info(f"Resolved model '{model}' to '{resolved_model}'")
        model = resolved_model
        # NOTE(review): nesting reconstructed from a flattened source --
        # confirm this assignment belongs inside the branch.
        kwargs["model"] = model  # Ensure kwargs has the resolved model for litellm

    # Drop credentials known to be too low-tier for this model, then build
    # the priority / tier-name maps handed to the usage manager.
    credentials_for_provider = self._select_tier_compatible_credentials(
        provider_plugin, credentials_for_provider, model
    )
    credential_priorities, credential_tier_names = (
        self._collect_credential_priorities(
            provider_plugin, credentials_for_provider, provider
        )
    )

    # Tracks errors across credential rotation for the final client report.
    error_accumulator = RequestErrorAccumulator()
    error_accumulator.model = model
    error_accumulator.provider = provider

    # Main rotation loop: runs while untried credentials remain and the
    # global deadline has not been exceeded.
    while (
        len(tried_creds) < len(credentials_for_provider) and time.time() < deadline
    ):
        current_cred = None
        key_acquired = False
        try:
            # Check for a provider-wide cooldown first.
            if await self.cooldown_manager.is_cooling_down(provider):
                remaining_cooldown = (
                    await self.cooldown_manager.get_cooldown_remaining(provider)
                )
                remaining_budget = deadline - time.time()
                # If the cooldown is longer than the remaining budget, fail fast.
                if remaining_cooldown > remaining_budget:
                    lib_logger.warning(
                        f"Provider {provider} cooldown ({remaining_cooldown:.2f}s) exceeds remaining request budget ({remaining_budget:.2f}s). Failing early."
                    )
                    break
                lib_logger.warning(
                    f"Provider {provider} is in cooldown. Waiting for {remaining_cooldown:.2f} seconds."
                )
                await asyncio.sleep(remaining_cooldown)

            creds_to_try = [
                c for c in credentials_for_provider if c not in tried_creds
            ]
            if not creds_to_try:
                break

            # Count of untried credentials not on cooldown for this model
            # (used for the log line only).
            creds_off_cooldown = (
                await self.usage_manager.get_available_credentials_for_model(
                    creds_to_try, model
                )
            )
            available_count = len(creds_off_cooldown)
            total_count = len(credentials_for_provider)
            lib_logger.info(
                f"Acquiring key for model {model}. Tried keys: {len(tried_creds)}/{available_count}({total_count})"
            )

            max_concurrent = self.max_concurrent_requests_per_key.get(provider, 1)
            current_cred = await self.usage_manager.acquire_key(
                available_keys=creds_to_try,
                model=model,
                deadline=deadline,
                max_concurrent=max_concurrent,
                credential_priorities=credential_priorities,
                credential_tier_names=credential_tier_names,
            )
            # From here on the `finally` clause is responsible for releasing
            # the key unless a success path already did so.
            key_acquired = True
            tried_creds.add(current_cred)

            litellm_kwargs = self.all_providers.get_provider_kwargs(**kwargs.copy())

            # Merge provider-specific params; request-level values win.
            if provider in self.litellm_provider_params:
                litellm_kwargs["litellm_params"] = {
                    **self.litellm_provider_params[provider],
                    **litellm_kwargs.get("litellm_params", {}),
                }

            provider_plugin = self._get_provider_instance(provider)

            # Model ID is already resolved before the loop, and
            # kwargs['model'] is updated; no further resolution needed here.

            # Apply model-specific options for custom providers.
            if provider_plugin and hasattr(provider_plugin, "get_model_options"):
                model_options = provider_plugin.get_model_options(model)
                if model_options:
                    for option, value in model_options.items():
                        if option == "reasoning_effort":
                            # Always overrides any request-supplied value.
                            litellm_kwargs["reasoning_effort"] = value
                        elif option not in litellm_kwargs:
                            litellm_kwargs[option] = value

            if provider_plugin and provider_plugin.has_custom_logic():
                lib_logger.debug(
                    f"Provider '{provider}' has custom logic. Delegating call."
                )
                litellm_kwargs["credential_identifier"] = current_cred
                litellm_kwargs["enable_request_logging"] = (
                    self.enable_request_logging
                )

                # Check body first for custom_reasoning_budget, then fall
                # back to the request header of the same name.
                if "custom_reasoning_budget" in kwargs:
                    litellm_kwargs["custom_reasoning_budget"] = kwargs[
                        "custom_reasoning_budget"
                    ]
                else:
                    custom_budget_header = None
                    if request and hasattr(request, "headers"):
                        custom_budget_header = request.headers.get(
                            "custom_reasoning_budget"
                        )
                    if custom_budget_header is not None:
                        is_budget_enabled = custom_budget_header.lower() == "true"
                        litellm_kwargs["custom_reasoning_budget"] = (
                            is_budget_enabled
                        )

                # Retry loop for custom providers - mirrors streaming path
                # error handling.
                for attempt in range(self.max_retries):
                    try:
                        lib_logger.info(
                            f"Attempting call with credential {mask_credential(current_cred)} (Attempt {attempt + 1}/{self.max_retries})"
                        )

                        if pre_request_callback:
                            try:
                                await pre_request_callback(request, litellm_kwargs)
                            except Exception as e:
                                if self.abort_on_callback_error:
                                    raise PreRequestCallbackError(
                                        f"Pre-request callback failed: {e}"
                                    ) from e
                                else:
                                    lib_logger.warning(
                                        f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}"
                                    )

                        response = await provider_plugin.acompletion(
                            self.http_client, **litellm_kwargs
                        )

                        # For non-streaming, success is immediate.
                        await self.usage_manager.record_success(
                            current_cred, model, response
                        )
                        await self.usage_manager.release_key(current_cred, model)
                        key_acquired = False  # already released; skip `finally`
                        return response

                    except (
                        litellm.RateLimitError,
                        httpx.HTTPStatusError,
                    ) as e:
                        last_exception = e
                        classified_error = classify_error(e, provider=provider)
                        error_message = str(e).split("\n")[0]
                        self._log_attempt_failure(
                            current_cred, model, attempt, e, request
                        )

                        # Record in accumulator for client reporting.
                        error_accumulator.record_error(
                            current_cred, classified_error, error_message
                        )

                        # Check if this error should trigger rotation.
                        if not should_rotate_on_error(classified_error):
                            lib_logger.error(
                                f"Non-recoverable error ({classified_error.error_type}) during custom provider call. Failing."
                            )
                            raise last_exception

                        # Handle rate limits with cooldown (exclude quota_exceeded).
                        if classified_error.error_type == "rate_limit":
                            cooldown_duration = classified_error.retry_after or 60
                            await self.cooldown_manager.start_cooldown(
                                provider, cooldown_duration
                            )

                        await self.usage_manager.record_failure(
                            current_cred, model, classified_error
                        )
                        lib_logger.warning(
                            f"Cred {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code}). Rotating."
                        )
                        break  # Rotate to next credential

                    except (
                        APIConnectionError,
                        litellm.InternalServerError,
                        litellm.ServiceUnavailableError,
                    ) as e:
                        last_exception = e
                        self._log_attempt_failure(
                            current_cred, model, attempt, e, request
                        )
                        classified_error = classify_error(e, provider=provider)
                        error_message = str(e).split("\n")[0]

                        # Provider-level error: don't increment consecutive failures.
                        await self.usage_manager.record_failure(
                            current_cred,
                            model,
                            classified_error,
                            increment_consecutive_failures=False,
                        )

                        if attempt >= self.max_retries - 1:
                            error_accumulator.record_error(
                                current_cred, classified_error, error_message
                            )
                            lib_logger.warning(
                                f"Cred {mask_credential(current_cred)} failed after max retries. Rotating."
                            )
                            break

                        # Provider-supplied retry_after wins; otherwise
                        # exponential backoff with up to 1s of jitter.
                        wait_time = classified_error.retry_after or (
                            (2**attempt) + random.uniform(0, 1)
                        )
                        remaining_budget = deadline - time.time()
                        if wait_time > remaining_budget:
                            error_accumulator.record_error(
                                current_cred, classified_error, error_message
                            )
                            lib_logger.warning(
                                f"Retry wait ({wait_time:.2f}s) exceeds budget. Rotating."
                            )
                            break

                        lib_logger.warning(
                            f"Cred {mask_credential(current_cred)} server error. Retrying in {wait_time:.2f}s."
                        )
                        await asyncio.sleep(wait_time)
                        continue

                    except Exception as e:
                        last_exception = e
                        self._log_attempt_failure(
                            current_cred, model, attempt, e, request
                        )
                        classified_error = classify_error(e, provider=provider)
                        error_message = str(e).split("\n")[0]

                        # Record in accumulator.
                        error_accumulator.record_error(
                            current_cred, classified_error, error_message
                        )
                        lib_logger.warning(
                            f"Cred {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code})."
                        )

                        # Check if this error should trigger rotation.
                        if not should_rotate_on_error(classified_error):
                            lib_logger.error(
                                f"Non-recoverable error ({classified_error.error_type}). Failing."
                            )
                            raise last_exception

                        # Handle rate limits with cooldown (exclude quota_exceeded).
                        if (
                            classified_error.status_code == 429
                            and classified_error.error_type != "quota_exceeded"
                        ) or classified_error.error_type == "rate_limit":
                            cooldown_duration = classified_error.retry_after or 60
                            await self.cooldown_manager.start_cooldown(
                                provider, cooldown_duration
                            )

                        await self.usage_manager.record_failure(
                            current_cred, model, classified_error
                        )
                        break  # Rotate to next credential

                # If the inner loop breaks, the key failed and we need to
                # rotate: go to the next iteration of the outer loop.
                continue

            else:  # This is the standard API Key / litellm-handled provider logic
                is_oauth = provider in self.oauth_providers
                if is_oauth:  # Standard OAuth provider (not custom)
                    # ... (logic to set headers) ...
                    pass
                else:  # API Key
                    litellm_kwargs["api_key"] = current_cred

                provider_instance = self._get_provider_instance(provider)
                if provider_instance:
                    # Ensure default Gemini safety settings are present
                    # (without overriding the request).
                    try:
                        self._apply_default_safety_settings(
                            litellm_kwargs, provider
                        )
                    except Exception:
                        # Best effort: never break the request flow here.
                        lib_logger.debug(
                            "Could not apply default safety settings; continuing."
                        )

                    if "safety_settings" in litellm_kwargs:
                        converted_settings = (
                            provider_instance.convert_safety_settings(
                                litellm_kwargs["safety_settings"]
                            )
                        )
                        if converted_settings is not None:
                            litellm_kwargs["safety_settings"] = converted_settings
                        else:
                            del litellm_kwargs["safety_settings"]

                if provider == "gemini" and provider_instance:
                    provider_instance.handle_thinking_parameter(
                        litellm_kwargs, model
                    )
                if provider == "nvidia_nim" and provider_instance:
                    provider_instance.handle_thinking_parameter(
                        litellm_kwargs, model
                    )

                # For gemma-3 models, system messages are re-labelled as
                # user messages (content unchanged).
                if "gemma-3" in model and "messages" in litellm_kwargs:
                    litellm_kwargs["messages"] = [
                        {"role": "user", "content": m["content"]}
                        if m.get("role") == "system"
                        else m
                        for m in litellm_kwargs["messages"]
                    ]

                litellm_kwargs = sanitize_request_payload(litellm_kwargs, model)

                for attempt in range(self.max_retries):
                    try:
                        lib_logger.info(
                            f"Attempting call with credential {mask_credential(current_cred)} (Attempt {attempt + 1}/{self.max_retries})"
                        )

                        if pre_request_callback:
                            try:
                                await pre_request_callback(request, litellm_kwargs)
                            except Exception as e:
                                if self.abort_on_callback_error:
                                    raise PreRequestCallbackError(
                                        f"Pre-request callback failed: {e}"
                                    ) from e
                                else:
                                    lib_logger.warning(
                                        f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}"
                                    )

                        # Convert model parameters for custom providers right
                        # before the LiteLLM call.
                        final_kwargs = self._convert_model_params_for_litellm(
                            **litellm_kwargs
                        )

                        response = await api_call(
                            **final_kwargs,
                            logger_fn=self._litellm_logger_callback,
                        )

                        await self.usage_manager.record_success(
                            current_cred, model, response
                        )
                        await self.usage_manager.release_key(current_cred, model)
                        key_acquired = False  # already released; skip `finally`
                        return response

                    except litellm.RateLimitError as e:
                        last_exception = e
                        self._log_attempt_failure(
                            current_cred, model, attempt, e, request
                        )
                        classified_error = classify_error(e, provider=provider)

                        # Extract a clean error message for the user-facing log.
                        error_message = str(e).split("\n")[0]

                        # Record in accumulator for client reporting.
                        error_accumulator.record_error(
                            current_cred, classified_error, error_message
                        )
                        lib_logger.info(
                            f"Key {mask_credential(current_cred)} hit rate limit for {model}. Rotating key."
                        )

                        # Only trigger provider-wide cooldown for rate
                        # limits, not quota issues.
                        if (
                            classified_error.status_code == 429
                            and classified_error.error_type != "quota_exceeded"
                        ):
                            cooldown_duration = classified_error.retry_after or 60
                            await self.cooldown_manager.start_cooldown(
                                provider, cooldown_duration
                            )

                        await self.usage_manager.record_failure(
                            current_cred, model, classified_error
                        )
                        break  # Move to the next key

                    except (
                        APIConnectionError,
                        litellm.InternalServerError,
                        litellm.ServiceUnavailableError,
                    ) as e:
                        last_exception = e
                        self._log_attempt_failure(
                            current_cred, model, attempt, e, request
                        )
                        classified_error = classify_error(e, provider=provider)
                        error_message = str(e).split("\n")[0]

                        # Provider-level error: don't increment consecutive failures.
                        await self.usage_manager.record_failure(
                            current_cred,
                            model,
                            classified_error,
                            increment_consecutive_failures=False,
                        )

                        if attempt >= self.max_retries - 1:
                            # Record in accumulator only on final failure for this key.
                            error_accumulator.record_error(
                                current_cred, classified_error, error_message
                            )
                            lib_logger.warning(
                                f"Key {mask_credential(current_cred)} failed after max retries due to server error. Rotating."
                            )
                            break  # Move to the next key

                        # For temporary errors, wait before retrying with the
                        # same key. Provider-supplied retry_after wins;
                        # otherwise exponential backoff with up to 1s jitter.
                        wait_time = classified_error.retry_after or (
                            (2**attempt) + random.uniform(0, 1)
                        )
                        remaining_budget = deadline - time.time()

                        # If the wait exceeds the budget, rotate immediately.
                        if wait_time > remaining_budget:
                            error_accumulator.record_error(
                                current_cred, classified_error, error_message
                            )
                            lib_logger.warning(
                                f"Retry wait ({wait_time:.2f}s) exceeds budget ({remaining_budget:.2f}s). Rotating key."
                            )
                            break

                        lib_logger.warning(
                            f"Key {mask_credential(current_cred)} server error. Retrying in {wait_time:.2f}s."
                        )
                        await asyncio.sleep(wait_time)
                        continue  # Retry with the same key

                    except httpx.HTTPStatusError as e:
                        # Handle HTTP errors from httpx (e.g., from custom
                        # providers like Antigravity).
                        last_exception = e
                        self._log_attempt_failure(
                            current_cred, model, attempt, e, request
                        )
                        classified_error = classify_error(e, provider=provider)
                        error_message = str(e).split("\n")[0]
                        lib_logger.warning(
                            f"Key {mask_credential(current_cred)} HTTP {e.response.status_code} ({classified_error.error_type})."
                        )

                        # Check if this error should trigger rotation.
                        if not should_rotate_on_error(classified_error):
                            lib_logger.error(
                                f"Non-recoverable error ({classified_error.error_type}). Failing request."
                            )
                            raise last_exception

                        # Record in accumulator after confirming it's a
                        # rotatable error.
                        error_accumulator.record_error(
                            current_cred, classified_error, error_message
                        )

                        # Handle rate limits with cooldown (exclude
                        # quota_exceeded from provider-wide cooldown).
                        if classified_error.error_type == "rate_limit":
                            cooldown_duration = classified_error.retry_after or 60
                            await self.cooldown_manager.start_cooldown(
                                provider, cooldown_duration
                            )

                        # Retry the same key for server errors while retries
                        # and time budget remain.
                        if (
                            should_retry_same_key(classified_error)
                            and attempt < self.max_retries - 1
                        ):
                            wait_time = classified_error.retry_after or (
                                (2**attempt) + random.uniform(0, 1)
                            )
                            remaining_budget = deadline - time.time()
                            if wait_time <= remaining_budget:
                                lib_logger.warning(
                                    f"Server error, retrying same key in {wait_time:.2f}s."
                                )
                                await asyncio.sleep(wait_time)
                                continue

                        # Record failure and rotate to next key.
                        await self.usage_manager.record_failure(
                            current_cred, model, classified_error
                        )
                        lib_logger.info(
                            f"Rotating to next key after {classified_error.error_type} error."
                        )
                        break

                    except Exception as e:
                        last_exception = e
                        self._log_attempt_failure(
                            current_cred, model, attempt, e, request
                        )

                        # No point rotating if nobody is waiting for the answer.
                        if request and await request.is_disconnected():
                            lib_logger.warning(
                                f"Client disconnected. Aborting retries for {mask_credential(current_cred)}."
                            )
                            raise last_exception

                        classified_error = classify_error(e, provider=provider)
                        error_message = str(e).split("\n")[0]
                        lib_logger.warning(
                            f"Key {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code})."
                        )

                        # Handle rate limits with cooldown (exclude
                        # quota_exceeded from provider-wide cooldown).
                        if (
                            classified_error.status_code == 429
                            and classified_error.error_type != "quota_exceeded"
                        ) or classified_error.error_type == "rate_limit":
                            cooldown_duration = classified_error.retry_after or 60
                            await self.cooldown_manager.start_cooldown(
                                provider, cooldown_duration
                            )

                        # Check if this error should trigger rotation.
                        if not should_rotate_on_error(classified_error):
                            lib_logger.error(
                                f"Non-recoverable error ({classified_error.error_type}). Failing request."
                            )
                            raise last_exception

                        # Record in accumulator after confirming it's a
                        # rotatable error.
                        error_accumulator.record_error(
                            current_cred, classified_error, error_message
                        )
                        await self.usage_manager.record_failure(
                            current_cred, model, classified_error
                        )
                        break  # Try next key for other errors
        finally:
            # Release the key on every failure / early-exit path; success
            # paths clear `key_acquired` after releasing it themselves.
            if key_acquired and current_cred:
                await self.usage_manager.release_key(current_cred, model)

    # Reaching here means all credentials were exhausted, the provider
    # cooldown did not fit the budget, or the deadline passed.
    if time.time() >= deadline:
        error_accumulator.timeout_occurred = True

    if error_accumulator.has_errors():
        # Log concise summary for server logs.
        lib_logger.error(error_accumulator.build_log_message())
        # Return the structured error response for the client.
        return error_accumulator.build_client_error_response()

    # Failure without error details (shouldn't normally happen).
    lib_logger.warning(
        "Unexpected state: request failed with no recorded errors. "
        "This may indicate a logic error in error tracking."
    )
    return None


def _select_tier_compatible_credentials(self, provider_plugin, credentials, model):
    """Return the credentials that may serve ``model`` given its tier requirement.

    Credentials whose priority is known and greater than the required tier
    are dropped. Credentials with unknown priority (``None``) are kept, and
    so is everything when the plugin exposes no ``get_credential_priority``.
    Known-compatible credentials are placed before unknown ones. If the
    filter would leave nothing, the input list is returned unchanged and a
    warning is logged.

    Args:
        provider_plugin: Provider instance, or a falsy value if none exists.
        credentials: Candidate credentials, in the order they should be tried.
        model: Resolved model identifier.

    Returns:
        The list of credentials to use (the same object as ``credentials``
        when no filtering applies).
    """
    if not (
        provider_plugin and hasattr(provider_plugin, "get_model_tier_requirement")
    ):
        return credentials

    required_tier = provider_plugin.get_model_tier_requirement(model)
    if required_tier is None:
        return credentials

    supports_priority = hasattr(provider_plugin, "get_credential_priority")
    incompatible_creds = []
    compatible_creds = []
    unknown_creds = []
    for cred in credentials:
        if not supports_priority:
            # Provider doesn't support priorities - keep all.
            unknown_creds.append(cred)
            continue
        priority = provider_plugin.get_credential_priority(cred)
        if priority is None:
            # Unknown priority - keep it, will be discovered on first use.
            unknown_creds.append(cred)
        elif priority <= required_tier:
            compatible_creds.append(cred)
        else:
            # Known incompatible priority (too low).
            incompatible_creds.append(cred)

    tier_compatible_creds = compatible_creds + unknown_creds
    if tier_compatible_creds:
        if compatible_creds and unknown_creds:
            lib_logger.info(
                f"Model {model} requires priority <= {required_tier}. "
                f"Using {len(compatible_creds)} known-compatible + {len(unknown_creds)} unknown-tier credentials."
            )
        elif compatible_creds:
            lib_logger.info(
                f"Model {model} requires priority <= {required_tier}. "
                f"Using {len(compatible_creds)} known-compatible credentials."
            )
        else:
            lib_logger.info(
                f"Model {model} requires priority <= {required_tier}. "
                f"Using {len(unknown_creds)} unknown-tier credentials (will discover on use)."
            )
        return tier_compatible_creds

    if incompatible_creds:
        # Only known-incompatible credentials remain; try them anyway.
        lib_logger.warning(
            f"Model {model} requires priority <= {required_tier} credentials, "
            f"but all {len(incompatible_creds)} known credentials have priority > {required_tier}. "
            f"Request will likely fail."
        )
    return credentials


def _collect_credential_priorities(self, provider_plugin, credentials, provider):
    """Build the priority and tier-name maps passed to the usage manager.

    Args:
        provider_plugin: Provider instance, or a falsy value if none exists.
        credentials: Credentials that will be offered for acquisition.
        provider: Provider name, used only in the debug log line.

    Returns:
        ``(credential_priorities, credential_tier_names)``. Both are ``None``
        when the plugin has no ``get_credential_priority``; otherwise both
        are dicts keyed by credential (possibly empty), containing only
        credentials with a non-``None`` priority / truthy tier name.
    """
    if not (provider_plugin and hasattr(provider_plugin, "get_credential_priority")):
        return None, None

    credential_priorities = {}
    credential_tier_names = {}
    # Occurrences per priority, gathered in the same pass so the debug line
    # does not have to re-scan the credential list for every priority.
    priority_counts = {}
    for cred in credentials:
        priority = provider_plugin.get_credential_priority(cred)
        if priority is not None:
            credential_priorities[cred] = priority
            priority_counts[priority] = priority_counts.get(priority, 0) + 1
        # Also get tier name for logging.
        # NOTE(review): nesting reconstructed from a flattened source --
        # confirm tier names are collected independently of a known priority.
        if hasattr(provider_plugin, "get_credential_tier_name"):
            tier_name = provider_plugin.get_credential_tier_name(cred)
            if tier_name:
                credential_tier_names[cred] = tier_name

    if credential_priorities:
        summary = ", ".join(
            f"P{p}={priority_counts[p]}" for p in sorted(priority_counts)
        )
        lib_logger.debug(f"Credential priorities for {provider}: {summary}")

    return credential_priorities, credential_tier_names


def _log_attempt_failure(self, credential, model, attempt, error, request):
    """Write one failed attempt to the failure log.

    Args:
        credential: Credential used for the attempt.
        model: Resolved model identifier.
        attempt: Zero-based attempt index (logged as ``attempt + 1``).
        error: The exception that was raised.
        request: Optional request object; its headers are logged when present.
    """
    log_failure(
        api_key=credential,
        model=model,
        attempt=attempt + 1,
        error=error,
        request_headers=dict(request.headers) if request else {},
    )
| async def _streaming_acompletion_with_retry( | |
| self, | |
| request: Optional[Any], | |
| pre_request_callback: Optional[callable] = None, | |
| **kwargs, | |
| ) -> AsyncGenerator[str, None]: | |
| """A dedicated generator for retrying streaming completions with full request preparation and per-key retries.""" | |
| model = kwargs.get("model") | |
| provider = model.split("/")[0] | |
| # Create a mutable copy of the keys and shuffle it. | |
| credentials_for_provider = list(self.all_credentials[provider]) | |
| random.shuffle(credentials_for_provider) | |
| # Filter out credentials that are unavailable (queued for re-auth) | |
| provider_plugin = self._get_provider_instance(provider) | |
| if provider_plugin and hasattr(provider_plugin, "is_credential_available"): | |
| available_creds = [ | |
| cred | |
| for cred in credentials_for_provider | |
| if provider_plugin.is_credential_available(cred) | |
| ] | |
| if available_creds: | |
| credentials_for_provider = available_creds | |
| # If all credentials are unavailable, keep the original list | |
| # (better to try unavailable creds than fail immediately) | |
| deadline = time.time() + self.global_timeout | |
| tried_creds = set() | |
| last_exception = None | |
| kwargs = self._convert_model_params(**kwargs) | |
| consecutive_quota_failures = 0 | |
| # Resolve model ID early, before any credential operations | |
| # This ensures consistent model ID usage for acquisition, release, and tracking | |
| resolved_model = self._resolve_model_id(model, provider) | |
| if resolved_model != model: | |
| lib_logger.info(f"Resolved model '{model}' to '{resolved_model}'") | |
| model = resolved_model | |
| kwargs["model"] = model # Ensure kwargs has the resolved model for litellm | |
| # [NEW] Filter by model tier requirement and build priority map | |
| credential_priorities = None | |
| if provider_plugin and hasattr(provider_plugin, "get_model_tier_requirement"): | |
| required_tier = provider_plugin.get_model_tier_requirement(model) | |
| if required_tier is not None: | |
| # Filter OUT only credentials we KNOW are too low priority | |
| # Keep credentials with unknown priority (None) - they might be high priority | |
| incompatible_creds = [] | |
| compatible_creds = [] | |
| unknown_creds = [] | |
| for cred in credentials_for_provider: | |
| if hasattr(provider_plugin, "get_credential_priority"): | |
| priority = provider_plugin.get_credential_priority(cred) | |
| if priority is None: | |
| # Unknown priority - keep it, will be discovered on first use | |
| unknown_creds.append(cred) | |
| elif priority <= required_tier: | |
| # Known compatible priority | |
| compatible_creds.append(cred) | |
| else: | |
| # Known incompatible priority (too low) | |
| incompatible_creds.append(cred) | |
| else: | |
| # Provider doesn't support priorities - keep all | |
| unknown_creds.append(cred) | |
| # If we have any known-compatible or unknown credentials, use them | |
| tier_compatible_creds = compatible_creds + unknown_creds | |
| if tier_compatible_creds: | |
| credentials_for_provider = tier_compatible_creds | |
| if compatible_creds and unknown_creds: | |
| lib_logger.info( | |
| f"Model {model} requires priority <= {required_tier}. " | |
| f"Using {len(compatible_creds)} known-compatible + {len(unknown_creds)} unknown-tier credentials." | |
| ) | |
| elif compatible_creds: | |
| lib_logger.info( | |
| f"Model {model} requires priority <= {required_tier}. " | |
| f"Using {len(compatible_creds)} known-compatible credentials." | |
| ) | |
| else: | |
| lib_logger.info( | |
| f"Model {model} requires priority <= {required_tier}. " | |
| f"Using {len(unknown_creds)} unknown-tier credentials (will discover on use)." | |
| ) | |
| elif incompatible_creds: | |
| # Only known-incompatible credentials remain | |
| lib_logger.warning( | |
| f"Model {model} requires priority <= {required_tier} credentials, " | |
| f"but all {len(incompatible_creds)} known credentials have priority > {required_tier}. " | |
| f"Request will likely fail." | |
| ) | |
| # Build priority map and tier names map for usage_manager | |
| credential_tier_names = None | |
| if provider_plugin and hasattr(provider_plugin, "get_credential_priority"): | |
| credential_priorities = {} | |
| credential_tier_names = {} | |
| for cred in credentials_for_provider: | |
| priority = provider_plugin.get_credential_priority(cred) | |
| if priority is not None: | |
| credential_priorities[cred] = priority | |
| # Also get tier name for logging | |
| if hasattr(provider_plugin, "get_credential_tier_name"): | |
| tier_name = provider_plugin.get_credential_tier_name(cred) | |
| if tier_name: | |
| credential_tier_names[cred] = tier_name | |
| if credential_priorities: | |
| lib_logger.debug( | |
| f"Credential priorities for {provider}: {', '.join(f'P{p}={len([c for c in credentials_for_provider if credential_priorities.get(c) == p])}' for p in sorted(set(credential_priorities.values())))}" | |
| ) | |
| # Initialize error accumulator for tracking errors across credential rotation | |
| error_accumulator = RequestErrorAccumulator() | |
| error_accumulator.model = model | |
| error_accumulator.provider = provider | |
| try: | |
| while ( | |
| len(tried_creds) < len(credentials_for_provider) | |
| and time.time() < deadline | |
| ): | |
| current_cred = None | |
| key_acquired = False | |
| try: | |
| if await self.cooldown_manager.is_cooling_down(provider): | |
| remaining_cooldown = ( | |
| await self.cooldown_manager.get_cooldown_remaining(provider) | |
| ) | |
| remaining_budget = deadline - time.time() | |
| if remaining_cooldown > remaining_budget: | |
| lib_logger.warning( | |
| f"Provider {provider} cooldown ({remaining_cooldown:.2f}s) exceeds remaining request budget ({remaining_budget:.2f}s). Failing early." | |
| ) | |
| break | |
| lib_logger.warning( | |
| f"Provider {provider} is in a global cooldown. All requests to this provider will be paused for {remaining_cooldown:.2f} seconds." | |
| ) | |
| await asyncio.sleep(remaining_cooldown) | |
| creds_to_try = [ | |
| c for c in credentials_for_provider if c not in tried_creds | |
| ] | |
| if not creds_to_try: | |
| lib_logger.warning( | |
| f"All credentials for provider {provider} have been tried. No more credentials to rotate to." | |
| ) | |
| break | |
| # Get count of credentials not on cooldown for this model | |
| available_creds = ( | |
| await self.usage_manager.get_available_credentials_for_model( | |
| creds_to_try, model | |
| ) | |
| ) | |
| available_count = len(available_creds) | |
| total_count = len(credentials_for_provider) | |
| lib_logger.info( | |
| f"Acquiring credential for model {model}. Tried credentials: {len(tried_creds)}/{available_count}({total_count})" | |
| ) | |
| max_concurrent = self.max_concurrent_requests_per_key.get( | |
| provider, 1 | |
| ) | |
| current_cred = await self.usage_manager.acquire_key( | |
| available_keys=creds_to_try, | |
| model=model, | |
| deadline=deadline, | |
| max_concurrent=max_concurrent, | |
| credential_priorities=credential_priorities, | |
| credential_tier_names=credential_tier_names, | |
| ) | |
| key_acquired = True | |
| tried_creds.add(current_cred) | |
| litellm_kwargs = self.all_providers.get_provider_kwargs( | |
| **kwargs.copy() | |
| ) | |
| if "reasoning_effort" in kwargs: | |
| litellm_kwargs["reasoning_effort"] = kwargs["reasoning_effort"] | |
| # Check body first for custom_reasoning_budget | |
| if "custom_reasoning_budget" in kwargs: | |
| litellm_kwargs["custom_reasoning_budget"] = kwargs[ | |
| "custom_reasoning_budget" | |
| ] | |
| else: | |
| custom_budget_header = None | |
| if request and hasattr(request, "headers"): | |
| custom_budget_header = request.headers.get( | |
| "custom_reasoning_budget" | |
| ) | |
| if custom_budget_header is not None: | |
| is_budget_enabled = custom_budget_header.lower() == "true" | |
| litellm_kwargs["custom_reasoning_budget"] = ( | |
| is_budget_enabled | |
| ) | |
| # [NEW] Merge provider-specific params | |
| if provider in self.litellm_provider_params: | |
| litellm_kwargs["litellm_params"] = { | |
| **self.litellm_provider_params[provider], | |
| **litellm_kwargs.get("litellm_params", {}), | |
| } | |
| provider_plugin = self._get_provider_instance(provider) | |
| # Model ID is already resolved before the loop, and kwargs['model'] is updated. | |
| # No further resolution needed here. | |
| # Apply model-specific options for custom providers | |
| if provider_plugin and hasattr( | |
| provider_plugin, "get_model_options" | |
| ): | |
| model_options = provider_plugin.get_model_options(model) | |
| if model_options: | |
| # Merge model options into litellm_kwargs | |
| for key, value in model_options.items(): | |
| if key == "reasoning_effort": | |
| litellm_kwargs["reasoning_effort"] = value | |
| elif key not in litellm_kwargs: | |
| litellm_kwargs[key] = value | |
| if provider_plugin and provider_plugin.has_custom_logic(): | |
| lib_logger.debug( | |
| f"Provider '{provider}' has custom logic. Delegating call." | |
| ) | |
| litellm_kwargs["credential_identifier"] = current_cred | |
| litellm_kwargs["enable_request_logging"] = ( | |
| self.enable_request_logging | |
| ) | |
| for attempt in range(self.max_retries): | |
| try: | |
| lib_logger.info( | |
| f"Attempting stream with credential {mask_credential(current_cred)} (Attempt {attempt + 1}/{self.max_retries})" | |
| ) | |
| if pre_request_callback: | |
| try: | |
| await pre_request_callback( | |
| request, litellm_kwargs | |
| ) | |
| except Exception as e: | |
| if self.abort_on_callback_error: | |
| raise PreRequestCallbackError( | |
| f"Pre-request callback failed: {e}" | |
| ) from e | |
| else: | |
| lib_logger.warning( | |
| f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}" | |
| ) | |
| response = await provider_plugin.acompletion( | |
| self.http_client, **litellm_kwargs | |
| ) | |
| lib_logger.info( | |
| f"Stream connection established for credential {mask_credential(current_cred)}. Processing response." | |
| ) | |
| key_acquired = False | |
| stream_generator = self._safe_streaming_wrapper( | |
| response, | |
| current_cred, | |
| model, | |
| request, | |
| provider_plugin, | |
| ) | |
| async for chunk in stream_generator: | |
| yield chunk | |
| return | |
| except ( | |
| StreamedAPIError, | |
| litellm.RateLimitError, | |
| httpx.HTTPStatusError, | |
| ) as e: | |
| last_exception = e | |
| # If the exception is our custom wrapper, unwrap the original error | |
| original_exc = getattr(e, "data", e) | |
| classified_error = classify_error( | |
| original_exc, provider=provider | |
| ) | |
| error_message = str(original_exc).split("\n")[0] | |
| log_failure( | |
| api_key=current_cred, | |
| model=model, | |
| attempt=attempt + 1, | |
| error=e, | |
| request_headers=dict(request.headers) | |
| if request | |
| else {}, | |
| ) | |
| # Record in accumulator for client reporting | |
| error_accumulator.record_error( | |
| current_cred, classified_error, error_message | |
| ) | |
| # Check if this error should trigger rotation | |
| if not should_rotate_on_error(classified_error): | |
| lib_logger.error( | |
| f"Non-recoverable error ({classified_error.error_type}) during custom stream. Failing." | |
| ) | |
| raise last_exception | |
| # Handle rate limits with cooldown (exclude quota_exceeded) | |
| if classified_error.error_type == "rate_limit": | |
| cooldown_duration = ( | |
| classified_error.retry_after or 60 | |
| ) | |
| await self.cooldown_manager.start_cooldown( | |
| provider, cooldown_duration | |
| ) | |
| await self.usage_manager.record_failure( | |
| current_cred, model, classified_error | |
| ) | |
| lib_logger.warning( | |
| f"Cred {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code}). Rotating." | |
| ) | |
| break | |
| except ( | |
| APIConnectionError, | |
| litellm.InternalServerError, | |
| litellm.ServiceUnavailableError, | |
| ) as e: | |
| last_exception = e | |
| log_failure( | |
| api_key=current_cred, | |
| model=model, | |
| attempt=attempt + 1, | |
| error=e, | |
| request_headers=dict(request.headers) | |
| if request | |
| else {}, | |
| ) | |
| classified_error = classify_error(e, provider=provider) | |
| error_message = str(e).split("\n")[0] | |
| # Provider-level error: don't increment consecutive failures | |
| await self.usage_manager.record_failure( | |
| current_cred, | |
| model, | |
| classified_error, | |
| increment_consecutive_failures=False, | |
| ) | |
| if attempt >= self.max_retries - 1: | |
| error_accumulator.record_error( | |
| current_cred, classified_error, error_message | |
| ) | |
| lib_logger.warning( | |
| f"Cred {mask_credential(current_cred)} failed after max retries. Rotating." | |
| ) | |
| break | |
| wait_time = classified_error.retry_after or ( | |
| 2**attempt | |
| ) + random.uniform(0, 1) | |
| remaining_budget = deadline - time.time() | |
| if wait_time > remaining_budget: | |
| error_accumulator.record_error( | |
| current_cred, classified_error, error_message | |
| ) | |
| lib_logger.warning( | |
| f"Retry wait ({wait_time:.2f}s) exceeds budget. Rotating." | |
| ) | |
| break | |
| lib_logger.warning( | |
| f"Cred {mask_credential(current_cred)} server error. Retrying in {wait_time:.2f}s." | |
| ) | |
| await asyncio.sleep(wait_time) | |
| continue | |
| except Exception as e: | |
| last_exception = e | |
| log_failure( | |
| api_key=current_cred, | |
| model=model, | |
| attempt=attempt + 1, | |
| error=e, | |
| request_headers=dict(request.headers) | |
| if request | |
| else {}, | |
| ) | |
| classified_error = classify_error(e, provider=provider) | |
| error_message = str(e).split("\n")[0] | |
| # Record in accumulator | |
| error_accumulator.record_error( | |
| current_cred, classified_error, error_message | |
| ) | |
| lib_logger.warning( | |
| f"Cred {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code})." | |
| ) | |
| # Check if this error should trigger rotation | |
| if not should_rotate_on_error(classified_error): | |
| lib_logger.error( | |
| f"Non-recoverable error ({classified_error.error_type}). Failing." | |
| ) | |
| raise last_exception | |
| await self.usage_manager.record_failure( | |
| current_cred, model, classified_error | |
| ) | |
| break | |
| # If the inner loop breaks, it means the key failed and we need to rotate. | |
| # Continue to the next iteration of the outer while loop to pick a new key. | |
| continue | |
| else: # This is the standard API Key / litellm-handled provider logic | |
| is_oauth = provider in self.oauth_providers | |
| if is_oauth: # Standard OAuth provider (not custom) | |
| # ... (logic to set headers) ... | |
| pass | |
| else: # API Key | |
| litellm_kwargs["api_key"] = current_cred | |
| provider_instance = self._get_provider_instance(provider) | |
| if provider_instance: | |
| # Ensure default Gemini safety settings are present (without overriding request) | |
| try: | |
| self._apply_default_safety_settings( | |
| litellm_kwargs, provider | |
| ) | |
| except Exception: | |
| lib_logger.debug( | |
| "Could not apply default safety settings for streaming path; continuing." | |
| ) | |
| if "safety_settings" in litellm_kwargs: | |
| converted_settings = ( | |
| provider_instance.convert_safety_settings( | |
| litellm_kwargs["safety_settings"] | |
| ) | |
| ) | |
| if converted_settings is not None: | |
| litellm_kwargs["safety_settings"] = converted_settings | |
| else: | |
| del litellm_kwargs["safety_settings"] | |
| if provider == "gemini" and provider_instance: | |
| provider_instance.handle_thinking_parameter( | |
| litellm_kwargs, model | |
| ) | |
| if provider == "nvidia_nim" and provider_instance: | |
| provider_instance.handle_thinking_parameter( | |
| litellm_kwargs, model | |
| ) | |
| if "gemma-3" in model and "messages" in litellm_kwargs: | |
| litellm_kwargs["messages"] = [ | |
| {"role": "user", "content": m["content"]} | |
| if m.get("role") == "system" | |
| else m | |
| for m in litellm_kwargs["messages"] | |
| ] | |
| litellm_kwargs = sanitize_request_payload(litellm_kwargs, model) | |
| # If the provider is 'qwen_code', set the custom provider to 'qwen' | |
| # and strip the prefix from the model name for LiteLLM. | |
| if provider == "qwen_code": | |
| litellm_kwargs["custom_llm_provider"] = "qwen" | |
| litellm_kwargs["model"] = model.split("/", 1)[1] | |
| for attempt in range(self.max_retries): | |
| try: | |
| lib_logger.info( | |
| f"Attempting stream with credential {mask_credential(current_cred)} (Attempt {attempt + 1}/{self.max_retries})" | |
| ) | |
| if pre_request_callback: | |
| try: | |
| await pre_request_callback(request, litellm_kwargs) | |
| except Exception as e: | |
| if self.abort_on_callback_error: | |
| raise PreRequestCallbackError( | |
| f"Pre-request callback failed: {e}" | |
| ) from e | |
| else: | |
| lib_logger.warning( | |
| f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}" | |
| ) | |
| # lib_logger.info(f"DEBUG: litellm.acompletion kwargs: {litellm_kwargs}") | |
| # Convert model parameters for custom providers right before LiteLLM call | |
| final_kwargs = self._convert_model_params_for_litellm( | |
| **litellm_kwargs | |
| ) | |
| response = await litellm.acompletion( | |
| **final_kwargs, | |
| logger_fn=self._litellm_logger_callback, | |
| ) | |
| lib_logger.info( | |
| f"Stream connection established for credential {mask_credential(current_cred)}. Processing response." | |
| ) | |
| key_acquired = False | |
| stream_generator = self._safe_streaming_wrapper( | |
| response, | |
| current_cred, | |
| model, | |
| request, | |
| provider_instance, | |
| ) | |
| async for chunk in stream_generator: | |
| yield chunk | |
| return | |
| except ( | |
| StreamedAPIError, | |
| litellm.RateLimitError, | |
| httpx.HTTPStatusError, | |
| ) as e: | |
| last_exception = e | |
| # This is the final, robust handler for streamed errors. | |
| error_payload = {} | |
| cleaned_str = None | |
| # The actual exception might be wrapped in our StreamedAPIError. | |
| original_exc = getattr(e, "data", e) | |
| classified_error = classify_error( | |
| original_exc, provider=provider | |
| ) | |
| # Check if this error should trigger rotation | |
| if not should_rotate_on_error(classified_error): | |
| lib_logger.error( | |
| f"Non-recoverable error ({classified_error.error_type}) during litellm stream. Failing." | |
| ) | |
| raise last_exception | |
| try: | |
| # The full error JSON is in the string representation of the exception. | |
| json_str_match = re.search( | |
| r"(\{.*\})", str(original_exc), re.DOTALL | |
| ) | |
| if json_str_match: | |
| cleaned_str = codecs.decode( | |
| json_str_match.group(1), "unicode_escape" | |
| ) | |
| error_payload = json.loads(cleaned_str) | |
| except (json.JSONDecodeError, TypeError): | |
| error_payload = {} | |
| log_failure( | |
| api_key=current_cred, | |
| model=model, | |
| attempt=attempt + 1, | |
| error=e, | |
| request_headers=dict(request.headers) | |
| if request | |
| else {}, | |
| raw_response_text=cleaned_str, | |
| ) | |
| error_details = error_payload.get("error", {}) | |
| error_status = error_details.get("status", "") | |
| error_message_text = error_details.get( | |
| "message", str(original_exc).split("\n")[0] | |
| ) | |
| # Record in accumulator for client reporting | |
| error_accumulator.record_error( | |
| current_cred, classified_error, error_message_text | |
| ) | |
| if ( | |
| "quota" in error_message_text.lower() | |
| or "resource_exhausted" in error_status.lower() | |
| ): | |
| consecutive_quota_failures += 1 | |
| quota_value = "N/A" | |
| quota_id = "N/A" | |
| if "details" in error_details and isinstance( | |
| error_details.get("details"), list | |
| ): | |
| for detail in error_details["details"]: | |
| if isinstance(detail.get("violations"), list): | |
| for violation in detail["violations"]: | |
| if "quotaValue" in violation: | |
| quota_value = violation[ | |
| "quotaValue" | |
| ] | |
| if "quotaId" in violation: | |
| quota_id = violation["quotaId"] | |
| if ( | |
| quota_value != "N/A" | |
| and quota_id != "N/A" | |
| ): | |
| break | |
| await self.usage_manager.record_failure( | |
| current_cred, model, classified_error | |
| ) | |
| if consecutive_quota_failures >= 3: | |
| # Fatal: likely input data too large | |
| client_error_message = ( | |
| f"Request failed after 3 consecutive quota errors (input may be too large). " | |
| f"Limit: {quota_value} (Quota ID: {quota_id})" | |
| ) | |
| lib_logger.error( | |
| f"Fatal quota error for {mask_credential(current_cred)}. ID: {quota_id}, Limit: {quota_value}" | |
| ) | |
| yield f"data: {json.dumps({'error': {'message': client_error_message, 'type': 'proxy_fatal_quota_error'}})}\n\n" | |
| yield "data: [DONE]\n\n" | |
| return | |
| else: | |
| lib_logger.warning( | |
| f"Cred {mask_credential(current_cred)} quota error ({consecutive_quota_failures}/3). Rotating." | |
| ) | |
| break | |
| else: | |
| consecutive_quota_failures = 0 | |
| lib_logger.warning( | |
| f"Cred {mask_credential(current_cred)} {classified_error.error_type}. Rotating." | |
| ) | |
| if classified_error.error_type == "rate_limit": | |
| cooldown_duration = ( | |
| classified_error.retry_after or 60 | |
| ) | |
| await self.cooldown_manager.start_cooldown( | |
| provider, cooldown_duration | |
| ) | |
| await self.usage_manager.record_failure( | |
| current_cred, model, classified_error | |
| ) | |
| break | |
| except ( | |
| APIConnectionError, | |
| litellm.InternalServerError, | |
| litellm.ServiceUnavailableError, | |
| ) as e: | |
| consecutive_quota_failures = 0 | |
| last_exception = e | |
| log_failure( | |
| api_key=current_cred, | |
| model=model, | |
| attempt=attempt + 1, | |
| error=e, | |
| request_headers=dict(request.headers) | |
| if request | |
| else {}, | |
| ) | |
| classified_error = classify_error(e, provider=provider) | |
| error_message_text = str(e).split("\n")[0] | |
| # Record error in accumulator (server errors are transient, not abnormal) | |
| error_accumulator.record_error( | |
| current_cred, classified_error, error_message_text | |
| ) | |
| # Provider-level error: don't increment consecutive failures | |
| await self.usage_manager.record_failure( | |
| current_cred, | |
| model, | |
| classified_error, | |
| increment_consecutive_failures=False, | |
| ) | |
| if attempt >= self.max_retries - 1: | |
| lib_logger.warning( | |
| f"Credential {mask_credential(current_cred)} failed after max retries for model {model} due to a server error. Rotating key silently." | |
| ) | |
| # [MODIFIED] Do not yield to the client here. | |
| break | |
| wait_time = classified_error.retry_after or ( | |
| 2**attempt | |
| ) + random.uniform(0, 1) | |
| remaining_budget = deadline - time.time() | |
| if wait_time > remaining_budget: | |
| lib_logger.warning( | |
| f"Required retry wait time ({wait_time:.2f}s) exceeds remaining budget ({remaining_budget:.2f}s). Rotating key early." | |
| ) | |
| break | |
| lib_logger.warning( | |
| f"Credential {mask_credential(current_cred)} encountered a server error for model {model}. Reason: '{error_message_text}'. Retrying in {wait_time:.2f}s." | |
| ) | |
| await asyncio.sleep(wait_time) | |
| continue | |
| except Exception as e: | |
| consecutive_quota_failures = 0 | |
| last_exception = e | |
| log_failure( | |
| api_key=current_cred, | |
| model=model, | |
| attempt=attempt + 1, | |
| error=e, | |
| request_headers=dict(request.headers) | |
| if request | |
| else {}, | |
| ) | |
| classified_error = classify_error(e, provider=provider) | |
| error_message_text = str(e).split("\n")[0] | |
| # Record error in accumulator | |
| error_accumulator.record_error( | |
| current_cred, classified_error, error_message_text | |
| ) | |
| lib_logger.warning( | |
| f"Credential {mask_credential(current_cred)} failed with {classified_error.error_type} (Status: {classified_error.status_code}). Error: {error_message_text}." | |
| ) | |
| # Handle rate limits with cooldown (exclude quota_exceeded) | |
| if ( | |
| classified_error.status_code == 429 | |
| and classified_error.error_type != "quota_exceeded" | |
| ) or classified_error.error_type == "rate_limit": | |
| cooldown_duration = classified_error.retry_after or 60 | |
| await self.cooldown_manager.start_cooldown( | |
| provider, cooldown_duration | |
| ) | |
| lib_logger.warning( | |
| f"Rate limit detected for {provider}. Starting {cooldown_duration}s cooldown." | |
| ) | |
| # Check if this error should trigger rotation | |
| if not should_rotate_on_error(classified_error): | |
| # Non-rotatable errors - fail immediately | |
| lib_logger.error( | |
| f"Non-recoverable error ({classified_error.error_type}). Failing request." | |
| ) | |
| raise last_exception | |
| # Record failure and rotate to next key | |
| await self.usage_manager.record_failure( | |
| current_cred, model, classified_error | |
| ) | |
| lib_logger.info( | |
| f"Rotating to next key after {classified_error.error_type} error." | |
| ) | |
| break | |
| finally: | |
| if key_acquired and current_cred: | |
| await self.usage_manager.release_key(current_cred, model) | |
| # Build detailed error response using error accumulator | |
| error_accumulator.timeout_occurred = time.time() >= deadline | |
| if error_accumulator.has_errors(): | |
| # Log concise summary for server logs | |
| lib_logger.error(error_accumulator.build_log_message()) | |
| # Build structured error response for client | |
| error_response = error_accumulator.build_client_error_response() | |
| error_data = error_response | |
| else: | |
| # Fallback if no errors were recorded (shouldn't happen) | |
| final_error_message = ( | |
| "Request failed: No available API keys after rotation or timeout." | |
| ) | |
| if last_exception: | |
| final_error_message = ( | |
| f"Request failed. Last error: {str(last_exception)}" | |
| ) | |
| error_data = { | |
| "error": {"message": final_error_message, "type": "proxy_error"} | |
| } | |
| lib_logger.error(final_error_message) | |
| yield f"data: {json.dumps(error_data)}\n\n" | |
| yield "data: [DONE]\n\n" | |
| except NoAvailableKeysError as e: | |
| lib_logger.error( | |
| f"A streaming request failed because no keys were available within the time budget: {e}" | |
| ) | |
| error_data = {"error": {"message": str(e), "type": "proxy_busy"}} | |
| yield f"data: {json.dumps(error_data)}\n\n" | |
| yield "data: [DONE]\n\n" | |
| except Exception as e: | |
| # This will now only catch fatal errors that should be raised, like invalid requests. | |
| lib_logger.error( | |
| f"An unhandled exception occurred in streaming retry logic: {e}", | |
| exc_info=True, | |
| ) | |
| error_data = { | |
| "error": { | |
| "message": f"An unexpected error occurred: {str(e)}", | |
| "type": "proxy_internal_error", | |
| } | |
| } | |
| yield f"data: {json.dumps(error_data)}\n\n" | |
| yield "data: [DONE]\n\n" | |
def acompletion(
    self,
    request: Optional[Any] = None,
    pre_request_callback: Optional[callable] = None,
    **kwargs,
) -> Union[Any, AsyncGenerator[str, None]]:
    """
    Dispatcher for completion requests.

    Routes to the streaming retry path when ``stream`` is truthy, otherwise to the
    generic retry path with ``litellm.acompletion``.

    Args:
        request: Optional request object, forwarded unchanged to the retry path.
        pre_request_callback: Optional async callback, forwarded unchanged to the retry path.
            The streaming path awaits it as ``pre_request_callback(request, prepared_kwargs)``
            before each attempt; whether a failing callback aborts the request is decided
            there (it consults ``self.abort_on_callback_error``).
        **kwargs: Completion parameters. Only ``model``, ``stream`` and ``stream_options``
            are inspected here; everything is forwarded.

    Returns:
        For streaming requests, the async generator produced by
        ``_streaming_acompletion_with_retry``. Otherwise whatever
        ``_execute_with_retry`` returns (presumably an awaitable resolving to the
        completion response - confirm against that method).
    """
    model = kwargs.get("model", "")
    provider = model.split("/")[0] if "/" in model else ""
    is_iflow = provider == "iflow"

    # Handle iflow provider: remove stream_options to avoid HTTP 406
    if is_iflow and "stream_options" in kwargs:
        lib_logger.debug(
            "Removing stream_options for iflow provider to avoid HTTP 406"
        )
        kwargs.pop("stream_options", None)

    if not kwargs.get("stream"):
        return self._execute_with_retry(
            litellm.acompletion,
            request=request,
            pre_request_callback=pre_request_callback,
            **kwargs,
        )

    # Only add stream_options for providers that support it (excluding iflow)
    if not is_iflow:
        # Work on a private copy: **kwargs is only a shallow copy, so writing into
        # the nested dict would leak "include_usage" into the caller's object.
        # "or {}" also tolerates an explicit stream_options=None.
        stream_options = dict(kwargs.get("stream_options") or {})
        # Respect an explicit include_usage value supplied by the caller.
        stream_options.setdefault("include_usage", True)
        kwargs["stream_options"] = stream_options

    return self._streaming_acompletion_with_retry(
        request=request, pre_request_callback=pre_request_callback, **kwargs
    )
def aembedding(
    self,
    request: Optional[Any] = None,
    pre_request_callback: Optional[callable] = None,
    **kwargs,
) -> Any:
    """
    Run an embedding request through the shared retry path.

    This is a thin forwarder: it hands ``litellm.aembedding`` plus every argument
    it received to ``self._execute_with_retry`` and returns that call's result.

    Args:
        request: Optional request object, forwarded unchanged.
        pre_request_callback: Optional callback, forwarded unchanged.
        **kwargs: Embedding parameters, forwarded unchanged.

    Returns:
        Whatever ``_execute_with_retry`` returns for this call.
    """
    # Same keyword set (and order) as passing them inline.
    forwarded = {
        "request": request,
        "pre_request_callback": pre_request_callback,
        **kwargs,
    }
    return self._execute_with_retry(litellm.aembedding, **forwarded)
def token_count(self, **kwargs) -> int:
    """
    Count tokens for either a chat ``messages`` list or a plain ``text`` string.

    The keyword arguments are first passed through ``self._convert_model_params``.
    ``messages`` takes precedence over ``text`` when both are truthy.

    Raises:
        ValueError: If ``model`` is missing/empty, or if neither ``messages`` nor
            ``text`` is truthy.
    """
    params = self._convert_model_params(**kwargs)
    model_name = params.get("model")
    plain_text = params.get("text")
    chat_messages = params.get("messages")

    if not model_name:
        raise ValueError("'model' is a required parameter.")

    # Pick exactly one input for the counter; messages win over text.
    if chat_messages:
        counter_input = {"messages": chat_messages}
    elif plain_text:
        counter_input = {"text": plain_text}
    else:
        raise ValueError("Either 'text' or 'messages' must be provided.")

    return token_counter(model=model_name, **counter_input)
async def get_available_models(self, provider: str) -> List[str]:
    """
    Return the filtered model list for ``provider``, caching the first success.

    Credentials are tried in random order until one ``get_models`` call succeeds.
    A model is kept when it is whitelisted, or when it is not ignored; the
    whitelist therefore overrides the ignore list.

    Returns:
        The cached or freshly fetched list, or ``[]`` when the provider has no
        credentials, no provider instance, or every credential failed.
    """
    lib_logger.info(f"Getting available models for provider: {provider}")

    if provider in self._model_list_cache:
        lib_logger.debug(f"Returning cached models for provider: {provider}")
        return self._model_list_cache[provider]

    provider_creds = self.all_credentials.get(provider)
    if not provider_creds:
        lib_logger.warning(f"No credentials for provider: {provider}")
        return []

    # Shuffle a copy so the stored credential order is untouched and the
    # starting credential varies between calls.
    candidates = list(provider_creds)
    random.shuffle(candidates)

    plugin = self._get_provider_instance(provider)
    if plugin:
        # Some providers ignore the credential in get_models; iterating is
        # harmless for those because the first success returns immediately.
        for credential in candidates:
            try:
                masked = mask_credential(credential)
                lib_logger.debug(
                    f"Attempting to get models for {provider} with credential {masked}"
                )
                fetched = await plugin.get_models(credential, self.http_client)
                lib_logger.info(
                    f"Got {len(fetched)} models for provider: {provider}"
                )

                # Whitelist and blacklist logic: both predicates are evaluated
                # for every model, whitelist first.
                kept = []
                for model_name in fetched:
                    allowed = self._is_model_whitelisted(provider, model_name)
                    ignored = self._is_model_ignored(provider, model_name)
                    if allowed or not ignored:
                        kept.append(model_name)

                removed = len(fetched) - len(kept)
                if removed:
                    lib_logger.info(
                        f"Filtered out {removed} models for provider {provider}."
                    )

                self._model_list_cache[provider] = kept
                return kept
            except Exception as exc:
                # Any failure with this credential just moves on to the next one.
                failure = classify_error(exc, provider=provider)
                masked = mask_credential(credential)
                lib_logger.debug(
                    f"Failed to get models for provider {provider} with credential {masked}: {failure.error_type}. Trying next credential."
                )

    lib_logger.error(
        f"Failed to get models for provider {provider} after trying all credentials."
    )
    return []
async def get_all_available_models(
    self, grouped: bool = True
) -> Union[Dict[str, List[str]], List[str]]:
    """
    Fetch the model lists of every provider in ``self.all_credentials`` concurrently.

    Args:
        grouped: When True (default) return ``{provider: [models]}``; when False
            return one flat list with the providers' models concatenated in
            provider order.

    Returns:
        The grouped dict or the flat list. A provider whose lookup raised (or was
        cancelled) contributes an empty list and is logged as an error.
    """
    lib_logger.info("Getting all available models...")

    all_providers = list(self.all_credentials.keys())
    tasks = [self.get_available_models(provider) for provider in all_providers]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    all_provider_models = {}
    for provider, result in zip(all_providers, results):
        # gather(return_exceptions=True) can also hand back asyncio.CancelledError,
        # which derives from BaseException (not Exception) on Python 3.8+. Checking
        # BaseException keeps an exception object from being stored as a model list.
        if isinstance(result, BaseException):
            lib_logger.error(
                f"Failed to get models for provider {provider}: {result}"
            )
            all_provider_models[provider] = []
        else:
            all_provider_models[provider] = result

    lib_logger.info("Finished getting all available models.")

    if grouped:
        return all_provider_models

    flat_models = []
    for models in all_provider_models.values():
        flat_models.extend(models)
    return flat_models
| async def get_quota_stats( | |
| self, | |
| provider_filter: Optional[str] = None, | |
| ) -> Dict[str, Any]: | |
| """ | |
| Get quota and usage stats for all credentials. | |
| This returns cached/disk data aggregated by provider. | |
| For provider-specific quota info (e.g., Antigravity quota groups), | |
| it enriches the data from provider plugins. | |
| Args: | |
| provider_filter: If provided, only return stats for this provider | |
| Returns: | |
| Complete stats dict ready for the /v1/quota-stats endpoint | |
| """ | |
| # Get base stats from usage manager | |
| stats = await self.usage_manager.get_stats_for_endpoint(provider_filter) | |
| # Enrich with provider-specific quota data | |
| for provider, prov_stats in stats.get("providers", {}).items(): | |
| provider_class = self._provider_plugins.get(provider) | |
| if not provider_class: | |
| continue | |
| # Get or create provider instance | |
| if provider not in self._provider_instances: | |
| self._provider_instances[provider] = provider_class() | |
| provider_instance = self._provider_instances[provider] | |
| # Check if provider has quota tracking (like Antigravity) | |
| if hasattr(provider_instance, "_get_effective_quota_groups"): | |
| # Add quota group summary | |
| quota_groups = provider_instance._get_effective_quota_groups() | |
| prov_stats["quota_groups"] = {} | |
| for group_name, group_models in quota_groups.items(): | |
| group_stats = { | |
| "models": group_models, | |
| "credentials_total": 0, | |
| "credentials_exhausted": 0, | |
| "avg_remaining_pct": 0, | |
| "total_remaining_pcts": [], | |
| # Total requests tracking across all credentials | |
| "total_requests_used": 0, | |
| "total_requests_max": 0, | |
| # Tier breakdown: tier_name -> {"total": N, "active": M} | |
| "tiers": {}, | |
| } | |
| # Calculate per-credential quota for this group | |
| for cred in prov_stats.get("credentials", []): | |
| models_data = cred.get("models", {}) | |
| group_stats["credentials_total"] += 1 | |
| # Track tier - get directly from provider cache since cred["tier"] not set yet | |
| tier = cred.get("tier") | |
| if not tier and hasattr( | |
| provider_instance, "project_tier_cache" | |
| ): | |
| cred_path = cred.get("full_path", "") | |
| tier = provider_instance.project_tier_cache.get(cred_path) | |
| tier = tier or "unknown" | |
| # Initialize tier entry if needed with priority for sorting | |
| if tier not in group_stats["tiers"]: | |
| priority = 10 # default | |
| if hasattr(provider_instance, "_resolve_tier_priority"): | |
| priority = provider_instance._resolve_tier_priority( | |
| tier | |
| ) | |
| group_stats["tiers"][tier] = { | |
| "total": 0, | |
| "active": 0, | |
| "priority": priority, | |
| } | |
| group_stats["tiers"][tier]["total"] += 1 | |
| # Find model with VALID baseline (not just any model with stats) | |
| model_stats = None | |
| for model in group_models: | |
| candidate = self._find_model_stats_in_data( | |
| models_data, model, provider, provider_instance | |
| ) | |
| if candidate: | |
| baseline = candidate.get("baseline_remaining_fraction") | |
| if baseline is not None: | |
| model_stats = candidate | |
| break | |
| # Keep first found as fallback (for request counts) | |
| if model_stats is None: | |
| model_stats = candidate | |
| if model_stats: | |
| baseline = model_stats.get("baseline_remaining_fraction") | |
| req_count = model_stats.get("request_count", 0) | |
| max_req = model_stats.get("quota_max_requests") or 0 | |
| # Accumulate totals (one model per group per credential) | |
| group_stats["total_requests_used"] += req_count | |
| group_stats["total_requests_max"] += max_req | |
| if baseline is not None: | |
| remaining_pct = int(baseline * 100) | |
| group_stats["total_remaining_pcts"].append( | |
| remaining_pct | |
| ) | |
| if baseline <= 0: | |
| group_stats["credentials_exhausted"] += 1 | |
| else: | |
| # Credential is active (has quota remaining) | |
| group_stats["tiers"][tier]["active"] += 1 | |
| # Calculate average remaining percentage (per-credential average) | |
| if group_stats["total_remaining_pcts"]: | |
| group_stats["avg_remaining_pct"] = int( | |
| sum(group_stats["total_remaining_pcts"]) | |
| / len(group_stats["total_remaining_pcts"]) | |
| ) | |
| del group_stats["total_remaining_pcts"] | |
| # Calculate total remaining percentage (global) | |
| if group_stats["total_requests_max"] > 0: | |
| used = group_stats["total_requests_used"] | |
| max_r = group_stats["total_requests_max"] | |
| group_stats["total_remaining_pct"] = max( | |
| 0, int((1 - used / max_r) * 100) | |
| ) | |
| else: | |
| group_stats["total_remaining_pct"] = None | |
| prov_stats["quota_groups"][group_name] = group_stats | |
| # Also enrich each credential with formatted quota group info | |
| for cred in prov_stats.get("credentials", []): | |
| cred["model_groups"] = {} | |
| models_data = cred.get("models", {}) | |
| for group_name, group_models in quota_groups.items(): | |
| # Find model with VALID baseline (prefer over any model with stats) | |
| # Also track the best reset_ts across all models in the group | |
| model_stats = None | |
| best_reset_ts = None | |
| for model in group_models: | |
| candidate = self._find_model_stats_in_data( | |
| models_data, model, provider, provider_instance | |
| ) | |
| if candidate: | |
| # Track the best (latest) reset_ts from any model in group | |
| candidate_reset_ts = candidate.get("quota_reset_ts") | |
| if candidate_reset_ts: | |
| if ( | |
| best_reset_ts is None | |
| or candidate_reset_ts > best_reset_ts | |
| ): | |
| best_reset_ts = candidate_reset_ts | |
| baseline = candidate.get("baseline_remaining_fraction") | |
| if baseline is not None: | |
| model_stats = candidate | |
| # Don't break - continue to find best reset_ts | |
| # Keep first found as fallback | |
| if model_stats is None: | |
| model_stats = candidate | |
| if model_stats: | |
| baseline = model_stats.get("baseline_remaining_fraction") | |
| max_req = model_stats.get("quota_max_requests") | |
| req_count = model_stats.get("request_count", 0) | |
| # Use best_reset_ts from any model in the group | |
| reset_ts = best_reset_ts or model_stats.get( | |
| "quota_reset_ts" | |
| ) | |
| remaining_pct = ( | |
| int(baseline * 100) if baseline is not None else None | |
| ) | |
| is_exhausted = baseline is not None and baseline <= 0 | |
| # Format reset time | |
| reset_iso = None | |
| if reset_ts: | |
| try: | |
| from datetime import datetime, timezone | |
| reset_iso = datetime.fromtimestamp( | |
| reset_ts, tz=timezone.utc | |
| ).isoformat() | |
| except (ValueError, OSError): | |
| pass | |
| cred["model_groups"][group_name] = { | |
| "remaining_pct": remaining_pct, | |
| "requests_used": req_count, | |
| "requests_max": max_req, | |
| "display": f"{req_count}/{max_req}" | |
| if max_req | |
| else f"{req_count}/?", | |
| "is_exhausted": is_exhausted, | |
| "reset_time_iso": reset_iso, | |
| "models": group_models, | |
| "confidence": self._get_baseline_confidence( | |
| model_stats | |
| ), | |
| } | |
| # Recalculate credential's requests from model_groups | |
| # This fixes double-counting when models share quota groups | |
| if cred.get("model_groups"): | |
| group_requests = sum( | |
| g.get("requests_used", 0) | |
| for g in cred["model_groups"].values() | |
| ) | |
| cred["requests"] = group_requests | |
| # HACK: Fix global requests if present | |
| # This is a simplified fix that sets global.requests = current group_requests. | |
| # TODO: Properly track archived requests per quota group in usage_manager.py | |
| # so that global stats correctly sum: current_period + archived_periods | |
| # without double-counting models that share quota groups. | |
| # See: usage_manager.py lines 2388-2404 where global stats are built | |
| # by iterating all models (causing double-counting for grouped models). | |
| if cred.get("global"): | |
| cred["global"]["requests"] = group_requests | |
| # Try to get the credential's tier from the provider's tier cache | |
| cred_path = cred.get("full_path", "") | |
| if hasattr(provider_instance, "project_tier_cache"): | |
| tier = provider_instance.project_tier_cache.get(cred_path) | |
| if tier: | |
| cred["tier"] = tier | |
| return stats | |
| def _find_model_stats_in_data( | |
| self, | |
| models_data: Dict[str, Any], | |
| model: str, | |
| provider: str, | |
| provider_instance: Any, | |
| ) -> Optional[Dict[str, Any]]: | |
| """ | |
| Find model stats in models_data, trying various name variants. | |
| Handles aliased model names (e.g., gemini-3-pro-preview -> gemini-3-pro-high) | |
| by using the provider's _user_to_api_model() mapping. | |
| Args: | |
| models_data: Dict of model_name -> stats from credential | |
| model: Model name to look up (user-facing name) | |
| provider: Provider name for prefixing | |
| provider_instance: Provider instance for alias methods | |
| Returns: | |
| Model stats dict if found, None otherwise | |
| """ | |
| # Try direct match with and without provider prefix | |
| prefixed_model = f"{provider}/{model}" | |
| model_stats = models_data.get(prefixed_model) or models_data.get(model) | |
| if model_stats: | |
| return model_stats | |
| # Try with API model name (e.g., gemini-3-pro-preview -> gemini-3-pro-high) | |
| if hasattr(provider_instance, "_user_to_api_model"): | |
| api_model = provider_instance._user_to_api_model(model) | |
| if api_model != model: | |
| prefixed_api = f"{provider}/{api_model}" | |
| model_stats = models_data.get(prefixed_api) or models_data.get( | |
| api_model | |
| ) | |
| return model_stats | |
| def _get_baseline_confidence(self, model_stats: Dict) -> str: | |
| """ | |
| Determine confidence level based on baseline age. | |
| Args: | |
| model_stats: Model statistics dict with baseline_fetched_at | |
| Returns: | |
| "high" | "medium" | "low" | |
| """ | |
| baseline_fetched_at = model_stats.get("baseline_fetched_at") | |
| if not baseline_fetched_at: | |
| return "low" | |
| age_seconds = time.time() - baseline_fetched_at | |
| if age_seconds < 300: # 5 minutes | |
| return "high" | |
| elif age_seconds < 1800: # 30 minutes | |
| return "medium" | |
| return "low" | |
async def reload_usage_from_disk(self) -> None:
    """
    Ask the usage manager to re-read its usage data from disk.

    This simply awaits usage_manager.reload_from_disk() and returns
    nothing. Intended for getting fresh stats without making external
    API calls.
    """
    usage_manager = self.usage_manager
    await usage_manager.reload_from_disk()
async def force_refresh_quota(
    self,
    provider: Optional[str] = None,
    credential: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Force refresh quota baselines from the providers' external APIs.

    Only providers whose plugin instance exposes fetch_initial_baselines()
    (e.g., Antigravity) are refreshed. Any other provider is skipped and
    contributes nothing to the result.

    Args:
        provider: If specified, only refresh this provider. A provider that
            has no entry in all_credentials yields an empty result.
        credential: If specified, only refresh the first credential of each
            selected provider whose path ends with this value (a full path
            or just a file name both work).

    Returns:
        Refresh result dict with:
          - "action": always "force_refresh"
          - "scope": "credential", "provider" or "all"
          - "provider" / "credential": the filters as passed in
          - "credentials_refreshed": credentials whose fetch completed
            without raising
          - "success_count": sum of the values returned by the provider's
            _store_baselines_to_usage_manager(), when it has that method
          - "failed_count": credentials whose fetched status is not
            "success", plus all credentials of a provider whose refresh
            raised
          - "duration_ms": elapsed time of the whole call in milliseconds
          - "errors": list of human-readable error strings
    """
    result: Dict[str, Any] = {
        "action": "force_refresh",
        "scope": "credential"
        if credential
        else ("provider" if provider else "all"),
        "provider": provider,
        "credential": credential,
        "credentials_refreshed": 0,
        "success_count": 0,
        "failed_count": 0,
        "duration_ms": 0,
        "errors": [],
    }

    # Monotonic clock: a wall-clock adjustment during the refresh must not
    # produce a negative or inflated duration.
    started = time.monotonic()

    # Determine which providers to refresh
    if provider:
        providers_to_refresh = (
            [provider] if provider in self.all_credentials else []
        )
    else:
        providers_to_refresh = list(self.all_credentials)

    for prov in providers_to_refresh:
        provider_class = self._provider_plugins.get(prov)
        if not provider_class:
            continue

        # Get or create provider instance
        if prov not in self._provider_instances:
            self._provider_instances[prov] = provider_class()
        provider_instance = self._provider_instances[prov]

        # Providers without live quota support are skipped entirely
        if not hasattr(provider_instance, "fetch_initial_baselines"):
            continue

        provider_creds = self.all_credentials.get(prov, [])
        if credential:
            # First credential whose path ends with the requested value;
            # endswith() also covers an exact full-path match.
            match = next(
                (path for path in provider_creds if path.endswith(credential)),
                None,
            )
            creds_to_refresh = [match] if match is not None else []
        else:
            creds_to_refresh = provider_creds

        if not creds_to_refresh:
            continue

        try:
            # Fetch live quota from API for ALL selected credentials
            quota_results = await provider_instance.fetch_initial_baselines(
                creds_to_refresh
            )

            # Store baselines in usage manager
            if hasattr(provider_instance, "_store_baselines_to_usage_manager"):
                stored = await provider_instance._store_baselines_to_usage_manager(
                    quota_results, self.usage_manager
                )
                result["success_count"] += stored

            result["credentials_refreshed"] += len(creds_to_refresh)

            # Count per-credential failures reported by the provider
            for cred_path, data in quota_results.items():
                if data.get("status") != "success":
                    result["failed_count"] += 1
                    result["errors"].append(
                        f"{Path(cred_path).name}: {data.get('error', 'Unknown error')}"
                    )
        except Exception as e:
            # Best-effort boundary: one provider failing must not abort the
            # refresh of the remaining providers.
            lib_logger.error("Failed to refresh quota for %s: %s", prov, e)
            result["errors"].append(f"{prov}: {e}")
            result["failed_count"] += len(creds_to_refresh)

    result["duration_ms"] = int((time.monotonic() - started) * 1000)
    return result