import asyncio
import fnmatch
import json
import re
import codecs
import time
import os
import random

import httpx
import litellm
from litellm.exceptions import APIConnectionError
from litellm.litellm_core_utils.token_counter import token_counter
import logging
from pathlib import Path
from typing import List, Dict, Any, AsyncGenerator, Optional, Union

lib_logger = logging.getLogger("rotator_library")

# Propagation is disabled by default so importing the library stays silent.
# When the client is constructed with configure_logging=True, propagation is
# enabled so the parent application (e.g. main.py) can control log levels
# and handlers centrally.
lib_logger.propagate = False

from .usage_manager import UsageManager
from .failure_logger import log_failure, configure_failure_logger
from .error_handler import (
    PreRequestCallbackError,
    CredentialNeedsReauthError,
    classify_error,
    AllProviders,
    NoAvailableKeysError,
    should_rotate_on_error,
    should_retry_same_key,
    RequestErrorAccumulator,
    mask_credential,
)
from .providers import PROVIDER_PLUGINS
from .providers.openai_compatible_provider import OpenAICompatibleProvider
from .request_sanitizer import sanitize_request_payload
from .cooldown_manager import CooldownManager
from .credential_manager import CredentialManager
from .background_refresher import BackgroundRefresher
from .model_definitions import ModelDefinitions
from .utils.paths import get_default_root, get_logs_dir, get_oauth_dir, get_data_file


class StreamedAPIError(Exception):
    """Custom exception to signal an API error received over a stream."""

    def __init__(self, message, data=None):
        super().__init__(message)
        self.data = data
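
# Illustrative sketch (comments only, not executed): StreamedAPIError lets the
# streaming wrapper convert an error that arrives *inside* an SSE stream into a
# typed exception the outer retry loop can catch and use to rotate credentials.
# The consumer shape below is an assumption for illustration, not part of this
# module's public API:
#
#     try:
#         async for chunk in wrapped_stream:
#             ...
#     except StreamedAPIError as e:
#         original = e.data  # parsed error payload, or the wrapped exception
#         ...                # classify, record failure, rotate credentials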

class RotatingClient:
    """
    A client that intelligently rotates and retries API keys using LiteLLM,
    with support for both streaming and non-streaming responses.
    """

    def __init__(
        self,
        api_keys: Optional[Dict[str, List[str]]] = None,
        oauth_credentials: Optional[Dict[str, List[str]]] = None,
        max_retries: int = 2,
        usage_file_path: Optional[Union[str, Path]] = None,
        configure_logging: bool = True,
        global_timeout: int = 30,
        abort_on_callback_error: bool = True,
        litellm_provider_params: Optional[Dict[str, Any]] = None,
        ignore_models: Optional[Dict[str, List[str]]] = None,
        whitelist_models: Optional[Dict[str, List[str]]] = None,
        enable_request_logging: bool = False,
        max_concurrent_requests_per_key: Optional[Dict[str, int]] = None,
        rotation_tolerance: float = 3.0,
        data_dir: Optional[Union[str, Path]] = None,
    ):
        """
        Initialize the RotatingClient with intelligent credential rotation.

        Args:
            api_keys: Dictionary mapping provider names to lists of API keys
            oauth_credentials: Dictionary mapping provider names to OAuth credential paths
            max_retries: Maximum number of retry attempts per credential
            usage_file_path: Path to store usage statistics. If None, uses data_dir/key_usage.json
            configure_logging: Whether to configure library logging
            global_timeout: Global timeout for requests in seconds
            abort_on_callback_error: Whether to abort on pre-request callback errors
            litellm_provider_params: Provider-specific parameters for LiteLLM
            ignore_models: Models to ignore/blacklist per provider
            whitelist_models: Models to explicitly whitelist per provider
            enable_request_logging: Whether to enable detailed request logging
            max_concurrent_requests_per_key: Max concurrent requests per key by provider
            rotation_tolerance: Tolerance for weighted random credential rotation.
                - 0.0: Deterministic, least-used credential always selected
                - 2.0-4.0 (default, recommended): Balanced randomness; can pick
                  credentials within a couple of uses of the least-used one
                - 5.0+: High randomness, more unpredictable selection patterns
            data_dir: Root directory for all data files (logs, cache, oauth_creds,
                key_usage.json). If None, auto-detects: EXE directory if frozen,
                else current working directory.
        """
        # Resolve data_dir early - this becomes the root for all file operations
        if data_dir is not None:
            self.data_dir = Path(data_dir).resolve()
        else:
            self.data_dir = get_default_root()

        # Configure failure logger to use the correct logs directory
        configure_failure_logger(get_logs_dir(self.data_dir))

        os.environ["LITELLM_LOG"] = "ERROR"
        litellm.set_verbose = False
        litellm.drop_params = True

        if configure_logging:
            # When True, this allows logs from this library to be handled
            # by the parent application's logging configuration.
            lib_logger.propagate = True
            # Remove any default handlers to prevent duplicate logging
            if lib_logger.hasHandlers():
                lib_logger.handlers.clear()
            lib_logger.addHandler(logging.NullHandler())
        else:
            lib_logger.propagate = False

        api_keys = api_keys or {}
        oauth_credentials = oauth_credentials or {}

        # Filter out providers with empty lists of credentials to ensure validity
        api_keys = {provider: keys for provider, keys in api_keys.items() if keys}
        oauth_credentials = {
            provider: paths for provider, paths in oauth_credentials.items() if paths
        }

        if not api_keys and not oauth_credentials:
            lib_logger.warning(
                "No provider credentials configured. The client will be unable to make any API requests."
            )

        self.api_keys = api_keys
        # Use provided oauth_credentials directly if available (already discovered by main.py)
        # Only call discover_and_prepare() if no credentials were passed
        if oauth_credentials:
            self.oauth_credentials = oauth_credentials
        else:
            self.credential_manager = CredentialManager(
                os.environ, oauth_dir=get_oauth_dir(self.data_dir)
            )
            self.oauth_credentials = self.credential_manager.discover_and_prepare()
        self.background_refresher = BackgroundRefresher(self)
        self.oauth_providers = set(self.oauth_credentials.keys())

        all_credentials = {}
        for provider, keys in api_keys.items():
            all_credentials.setdefault(provider, []).extend(keys)
        for provider, paths in self.oauth_credentials.items():
            all_credentials.setdefault(provider, []).extend(paths)
        self.all_credentials = all_credentials

        self.max_retries = max_retries
        self.global_timeout = global_timeout
        self.abort_on_callback_error = abort_on_callback_error

        # Initialize provider plugins early so they can be used for rotation mode detection
        self._provider_plugins = PROVIDER_PLUGINS
        self._provider_instances = {}

        # Build provider rotation modes map.
        # Each provider can specify its preferred rotation mode ("balanced" or "sequential").
        provider_rotation_modes = {}
        for provider in self.all_credentials.keys():
            provider_class = self._provider_plugins.get(provider)
            if provider_class and hasattr(provider_class, "get_rotation_mode"):
                # Use class method to get rotation mode (checks env var + class default)
                mode = provider_class.get_rotation_mode(provider)
            else:
                # Fallback: check environment variable directly
                env_key = f"ROTATION_MODE_{provider.upper()}"
                mode = os.getenv(env_key, "balanced")
            provider_rotation_modes[provider] = mode
            if mode != "balanced":
                lib_logger.info(f"Provider '{provider}' using rotation mode: {mode}")
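
        # Example (illustrative): forcing sequential rotation for one provider
        # via the ROTATION_MODE_<PROVIDER> fallback read above; the provider
        # name "gemini" is just an example:
        #
        #     ROTATION_MODE_GEMINI=sequential
        #
        # With no override, providers default to "balanced" rotation.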

        # Build priority-based concurrency multiplier maps.
        # These are universal multipliers based on credential tier/priority.
        priority_multipliers: Dict[str, Dict[int, int]] = {}
        priority_multipliers_by_mode: Dict[str, Dict[str, Dict[int, int]]] = {}
        sequential_fallback_multipliers: Dict[str, int] = {}

        for provider in self.all_credentials.keys():
            provider_class = self._provider_plugins.get(provider)

            # Start with provider class defaults
            if provider_class:
                # Get default priority multipliers from provider class
                if hasattr(provider_class, "default_priority_multipliers"):
                    default_multipliers = provider_class.default_priority_multipliers
                    if default_multipliers:
                        priority_multipliers[provider] = dict(default_multipliers)

                # Get sequential fallback from provider class
                if hasattr(provider_class, "default_sequential_fallback_multiplier"):
                    fallback = provider_class.default_sequential_fallback_multiplier
                    if fallback != 1:  # Only store if different from global default
                        sequential_fallback_multipliers[provider] = fallback

            # Override with environment variables
            # Format: CONCURRENCY_MULTIPLIER_<PROVIDER>_PRIORITY_<N>=<multiplier>
            # Format: CONCURRENCY_MULTIPLIER_<PROVIDER>_PRIORITY_<N>_<MODE>=<multiplier>
            for key, value in os.environ.items():
                prefix = f"CONCURRENCY_MULTIPLIER_{provider.upper()}_PRIORITY_"
                if key.startswith(prefix):
                    remainder = key[len(prefix) :]
                    try:
                        multiplier = int(value)
                        if multiplier < 1:
                            lib_logger.warning(f"Invalid {key}: {value}. Must be >= 1.")
                            continue

                        # Check if mode-specific (e.g., _PRIORITY_1_SEQUENTIAL)
                        if "_" in remainder:
                            parts = remainder.rsplit("_", 1)
                            priority = int(parts[0])
                            mode = parts[1].lower()
                            if mode in ("sequential", "balanced"):
                                # Mode-specific override
                                if provider not in priority_multipliers_by_mode:
                                    priority_multipliers_by_mode[provider] = {}
                                if mode not in priority_multipliers_by_mode[provider]:
                                    priority_multipliers_by_mode[provider][mode] = {}
                                priority_multipliers_by_mode[provider][mode][
                                    priority
                                ] = multiplier
                                lib_logger.info(
                                    f"Provider '{provider}' priority {priority} ({mode} mode) multiplier: {multiplier}x"
                                )
                            else:
                                # Unrecognized mode suffix - warn and skip
                                lib_logger.warning(f"Unknown mode in {key}: {mode}")
                        else:
                            # Universal priority multiplier
                            priority = int(remainder)
                            if provider not in priority_multipliers:
                                priority_multipliers[provider] = {}
                            priority_multipliers[provider][priority] = multiplier
                            lib_logger.info(
                                f"Provider '{provider}' priority {priority} multiplier: {multiplier}x"
                            )
                    except ValueError:
                        lib_logger.warning(
                            f"Invalid {key}: {value}. Could not parse priority or multiplier."
                        )
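
        # Example (illustrative): with GEMINI credentials, the overrides
        #
        #     CONCURRENCY_MULTIPLIER_GEMINI_PRIORITY_1=4
        #     CONCURRENCY_MULTIPLIER_GEMINI_PRIORITY_1_SEQUENTIAL=2
        #
        # would yield priority_multipliers == {"gemini": {1: 4}} and
        # priority_multipliers_by_mode == {"gemini": {"sequential": {1: 2}}}.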

        # Log configured multipliers
        for provider, multipliers in priority_multipliers.items():
            if multipliers:
                lib_logger.info(
                    f"Provider '{provider}' priority multipliers: {multipliers}"
                )
        for provider, fallback in sequential_fallback_multipliers.items():
            lib_logger.info(
                f"Provider '{provider}' sequential fallback multiplier: {fallback}x"
            )

        # Resolve usage file path - use provided path or default to data_dir
        if usage_file_path is not None:
            resolved_usage_path = Path(usage_file_path)
        else:
            resolved_usage_path = self.data_dir / "key_usage.json"

        self.usage_manager = UsageManager(
            file_path=resolved_usage_path,
            rotation_tolerance=rotation_tolerance,
            provider_rotation_modes=provider_rotation_modes,
            provider_plugins=PROVIDER_PLUGINS,
            priority_multipliers=priority_multipliers,
            priority_multipliers_by_mode=priority_multipliers_by_mode,
            sequential_fallback_multipliers=sequential_fallback_multipliers,
        )
        self._model_list_cache = {}
        self.http_client = httpx.AsyncClient()
        self.all_providers = AllProviders()
        self.cooldown_manager = CooldownManager()
        self.litellm_provider_params = litellm_provider_params or {}
        self.ignore_models = ignore_models or {}
        self.whitelist_models = whitelist_models or {}
        self.enable_request_logging = enable_request_logging
        self.model_definitions = ModelDefinitions()

        # Store and validate max concurrent requests per key
        self.max_concurrent_requests_per_key = max_concurrent_requests_per_key or {}
        # Validate all values are >= 1
        for provider, max_val in self.max_concurrent_requests_per_key.items():
            if max_val < 1:
                lib_logger.warning(
                    f"Invalid max_concurrent for '{provider}': {max_val}. Setting to 1."
                )
                self.max_concurrent_requests_per_key[provider] = 1
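
    # Example (illustrative sketch, not executed): constructing the client and
    # using it as an async context manager (__aenter__/__aexit__ below close the
    # shared httpx client). Key values are placeholders, and the public request
    # methods are assumed to exist outside this excerpt:
    #
    #     client = RotatingClient(
    #         api_keys={"gemini": ["AIza...a", "AIza...b"]},
    #         max_retries=2,
    #         rotation_tolerance=3.0,
    #     )
    #     async with client:
    #         ...  # issue requests via the client's public request methods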
- "*-preview*" - contains wildcard (matches anything with -preview) - "*" - match all """ model_provider = model_id.split("/")[0] if model_provider not in self.whitelist_models: return False whitelist = self.whitelist_models[model_provider] try: # This is the model name as the provider sees it (e.g., "gpt-4" or "google/gemma-7b") provider_model_name = model_id.split("/", 1)[1] except IndexError: provider_model_name = model_id for whitelisted_pattern in whitelist: # Use fnmatch for full glob pattern support if fnmatch.fnmatch( provider_model_name, whitelisted_pattern ) or fnmatch.fnmatch(model_id, whitelisted_pattern): return True return False def _sanitize_litellm_log(self, log_data: dict) -> dict: """ Recursively removes large data fields and sensitive information from litellm log dictionaries to keep debug logs clean and secure. """ if not isinstance(log_data, dict): return log_data # Keys to remove at any level of the dictionary keys_to_pop = [ "messages", "input", "response", "data", "api_key", "api_base", "original_response", "additional_args", ] # Keys that might contain nested dictionaries to clean nested_keys = ["kwargs", "litellm_params", "model_info", "proxy_server_request"] # Create a deep copy to avoid modifying the original log object in memory clean_data = json.loads(json.dumps(log_data, default=str)) def clean_recursively(data_dict): if not isinstance(data_dict, dict): return # Remove sensitive/large keys for key in keys_to_pop: data_dict.pop(key, None) # Recursively clean nested dictionaries for key in nested_keys: if key in data_dict and isinstance(data_dict[key], dict): clean_recursively(data_dict[key]) # Also iterate through all values to find any other nested dicts for key, value in list(data_dict.items()): if isinstance(value, dict): clean_recursively(value) clean_recursively(clean_data) return clean_data def _litellm_logger_callback(self, log_data: dict): """ Callback function to redirect litellm's logs to the library's logger. This allows us to control the log level and destination of litellm's output. It also cleans up error logs for better readability in debug files. """ # Filter out verbose pre_api_call and post_api_call logs log_event_type = log_data.get("log_event_type") if log_event_type in ["pre_api_call", "post_api_call"]: return # Skip these verbose logs entirely # For successful calls or pre-call logs, a simple debug message is enough. if not log_data.get("exception"): sanitized_log = self._sanitize_litellm_log(log_data) # We log it at the DEBUG level to ensure it goes to the debug file # and not the console, based on the main.py configuration. lib_logger.debug(f"LiteLLM Log: {sanitized_log}") return # For failures, extract key info to make debug logs more readable. 

    def _litellm_logger_callback(self, log_data: dict):
        """
        Callback function to redirect litellm's logs to the library's logger.
        This allows us to control the log level and destination of litellm's output.
        It also cleans up error logs for better readability in debug files.
        """
        # Filter out verbose pre_api_call and post_api_call logs
        log_event_type = log_data.get("log_event_type")
        if log_event_type in ["pre_api_call", "post_api_call"]:
            return  # Skip these verbose logs entirely

        # For successful calls or pre-call logs, a simple debug message is enough.
        if not log_data.get("exception"):
            sanitized_log = self._sanitize_litellm_log(log_data)
            # We log it at the DEBUG level to ensure it goes to the debug file
            # and not the console, based on the main.py configuration.
            lib_logger.debug(f"LiteLLM Log: {sanitized_log}")
            return

        # For failures, extract key info to make debug logs more readable.
        model = log_data.get("model", "N/A")
        call_id = log_data.get("litellm_call_id", "N/A")
        error_info = log_data.get("standard_logging_object", {}).get(
            "error_information", {}
        )
        error_class = error_info.get("error_class", "UnknownError")
        error_message = error_info.get(
            "error_message", str(log_data.get("exception", ""))
        )
        error_message = " ".join(error_message.split())  # Collapse whitespace

        lib_logger.debug(
            f"LiteLLM Callback Handled Error: Model={model} | "
            f"Type={error_class} | Message='{error_message}'"
        )

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    async def close(self):
        """Close the HTTP client to prevent resource leaks."""
        if hasattr(self, "http_client") and self.http_client:
            await self.http_client.aclose()

    def _convert_model_params(self, **kwargs) -> Dict[str, Any]:
        """
        Converts model parameters for specific providers.
        For example, the 'chutes' provider requires the model name to be
        prepended with 'openai/' and a specific 'api_base'.
        """
        model = kwargs.get("model")
        if not model:
            return kwargs

        provider = model.split("/")[0]
        if provider == "chutes":
            kwargs["model"] = f"openai/{model.split('/', 1)[1]}"
            kwargs["api_base"] = "https://llm.chutes.ai/v1"
        return kwargs

    def _convert_model_params_for_litellm(self, **kwargs) -> Dict[str, Any]:
        """
        Converts model parameters specifically for LiteLLM calls.
        This is called right before calling LiteLLM to handle custom providers.
        """
        model = kwargs.get("model")
        if not model:
            return kwargs

        provider = model.split("/")[0]

        # Handle custom OpenAI-compatible providers: a custom provider is
        # identified by the presence of an API_BASE environment variable.
        api_base_env = f"{provider.upper()}_API_BASE"
        if os.getenv(api_base_env):
            # For custom providers, tell LiteLLM to use the openai provider with
            # a converted model name. This preserves the original model name in
            # logs while converting it for LiteLLM.
            kwargs = kwargs.copy()  # Don't modify the original
            kwargs["model"] = f"openai/{model.split('/', 1)[1]}"
            kwargs["api_base"] = os.getenv(api_base_env).rstrip("/")
            kwargs["custom_llm_provider"] = "openai"

        return kwargs
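
    # Example (illustrative): a custom OpenAI-compatible provider named
    # "myprovider" (placeholder) is enabled purely via the environment:
    #
    #     MYPROVIDER_API_BASE=https://api.example.com/v1
    #
    # A request for "myprovider/some-model" is then rewritten above to
    # model="openai/some-model" with api_base pointing at that URL.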
""" if provider != "gemini": return # Generic defaults (openai-compatible style) default_generic = { "harassment": "OFF", "hate_speech": "OFF", "sexually_explicit": "OFF", "dangerous_content": "OFF", "civic_integrity": "BLOCK_NONE", } # Gemini defaults (direct Gemini format) default_gemini = [ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"}, {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"}, {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "BLOCK_NONE"}, ] # If generic form is present, ensure missing generic keys are filled in if "safety_settings" in litellm_kwargs and isinstance( litellm_kwargs["safety_settings"], dict ): for k, v in default_generic.items(): if k not in litellm_kwargs["safety_settings"]: litellm_kwargs["safety_settings"][k] = v return # If Gemini form is present, ensure missing gemini categories are appended if "safetySettings" in litellm_kwargs and isinstance( litellm_kwargs["safetySettings"], list ): present = { item.get("category") for item in litellm_kwargs["safetySettings"] if isinstance(item, dict) } for d in default_gemini: if d["category"] not in present: litellm_kwargs["safetySettings"].append(d) return # Neither present: set generic defaults so provider conversion will translate them if ( "safety_settings" not in litellm_kwargs and "safetySettings" not in litellm_kwargs ): litellm_kwargs["safety_settings"] = default_generic.copy() def get_oauth_credentials(self) -> Dict[str, List[str]]: return self.oauth_credentials def _is_custom_openai_compatible_provider(self, provider_name: str) -> bool: """Checks if a provider is a custom OpenAI-compatible provider.""" import os # Check if the provider has an API_BASE environment variable api_base_env = f"{provider_name.upper()}_API_BASE" return os.getenv(api_base_env) is not None def _get_provider_instance(self, provider_name: str): """ Lazily initializes and returns a provider instance. Only initializes providers that have configured credentials. Args: provider_name: The name of the provider to get an instance for. For OAuth providers, this may include "_oauth" suffix (e.g., "antigravity_oauth"), but credentials are stored under the base name (e.g., "antigravity"). Returns: Provider instance if credentials exist, None otherwise. 
""" # For OAuth providers, credentials are stored under base name (without _oauth suffix) # e.g., "antigravity_oauth" plugin → credentials under "antigravity" credential_key = provider_name if provider_name.endswith("_oauth"): base_name = provider_name[:-6] # Remove "_oauth" if base_name in self.oauth_providers: credential_key = base_name # Only initialize providers for which we have credentials if credential_key not in self.all_credentials: lib_logger.debug( f"Skipping provider '{provider_name}' initialization: no credentials configured" ) return None if provider_name not in self._provider_instances: if provider_name in self._provider_plugins: self._provider_instances[provider_name] = self._provider_plugins[ provider_name ]() elif self._is_custom_openai_compatible_provider(provider_name): # Create a generic OpenAI-compatible provider for custom providers try: self._provider_instances[provider_name] = OpenAICompatibleProvider( provider_name ) except ValueError: # If the provider doesn't have the required environment variables, treat it as a standard provider return None else: return None return self._provider_instances[provider_name] def _resolve_model_id(self, model: str, provider: str) -> str: """ Resolves the actual model ID to send to the provider. For custom models with name/ID mappings, returns the ID. Otherwise, returns the model name unchanged. Args: model: Full model string with provider (e.g., "iflow/DS-v3.2") provider: Provider name (e.g., "iflow") Returns: Full model string with ID (e.g., "iflow/deepseek-v3.2") """ # Extract model name from "provider/model_name" format model_name = model.split("/")[-1] if "/" in model else model # Try to get provider instance to check for model definitions provider_plugin = self._get_provider_instance(provider) # Check if provider has model definitions if provider_plugin and hasattr(provider_plugin, "model_definitions"): model_id = provider_plugin.model_definitions.get_model_id( provider, model_name ) if model_id and model_id != model_name: # Return with provider prefix return f"{provider}/{model_id}" # Fallback: use client's own model definitions model_id = self.model_definitions.get_model_id(provider, model_name) if model_id and model_id != model_name: return f"{provider}/{model_id}" # No conversion needed, return original return model async def _safe_streaming_wrapper( self, stream: Any, key: str, model: str, request: Optional[Any] = None, provider_plugin: Optional[Any] = None, ) -> AsyncGenerator[Any, None]: """ A hybrid wrapper for streaming that buffers fragmented JSON, handles client disconnections gracefully, and distinguishes between content and streamed errors. FINISH_REASON HANDLING: Providers just translate chunks - this wrapper handles ALL finish_reason logic: 1. Strip finish_reason from intermediate chunks (litellm defaults to "stop") 2. Track accumulated_finish_reason with priority: tool_calls > length/content_filter > stop 3. Only emit finish_reason on final chunk (detected by usage.completion_tokens > 0) """ last_usage = None stream_completed = False stream_iterator = stream.__aiter__() json_buffer = "" accumulated_finish_reason = None # Track strongest finish_reason across chunks has_tool_calls = False # Track if ANY tool calls were seen in stream try: while True: if request and await request.is_disconnected(): lib_logger.info( f"Client disconnected. Aborting stream for credential {mask_credential(key)}." 

    async def _safe_streaming_wrapper(
        self,
        stream: Any,
        key: str,
        model: str,
        request: Optional[Any] = None,
        provider_plugin: Optional[Any] = None,
    ) -> AsyncGenerator[Any, None]:
        """
        A hybrid wrapper for streaming that buffers fragmented JSON, handles
        client disconnections gracefully, and distinguishes between content
        and streamed errors.

        FINISH_REASON HANDLING:
        Providers just translate chunks - this wrapper handles ALL finish_reason logic:
        1. Strip finish_reason from intermediate chunks (litellm defaults to "stop")
        2. Track accumulated_finish_reason with priority: tool_calls > length/content_filter > stop
        3. Only emit finish_reason on final chunk (detected by usage.completion_tokens > 0)
        """
        last_usage = None
        stream_completed = False
        stream_iterator = stream.__aiter__()
        json_buffer = ""
        accumulated_finish_reason = None  # Track strongest finish_reason across chunks
        has_tool_calls = False  # Track if ANY tool calls were seen in stream

        try:
            while True:
                if request and await request.is_disconnected():
                    lib_logger.info(
                        f"Client disconnected. Aborting stream for credential {mask_credential(key)}."
                    )
                    break
                try:
                    chunk = await stream_iterator.__anext__()
                    if json_buffer:
                        lib_logger.warning(
                            f"Discarding incomplete JSON buffer from previous chunk: {json_buffer}"
                        )
                        json_buffer = ""

                    # Convert chunk to dict, handling both litellm.ModelResponse and raw dicts
                    if hasattr(chunk, "dict"):
                        chunk_dict = chunk.dict()
                    elif hasattr(chunk, "model_dump"):
                        chunk_dict = chunk.model_dump()
                    else:
                        chunk_dict = chunk

                    # === FINISH_REASON LOGIC ===
                    # Providers send raw chunks without finish_reason logic.
                    # This wrapper determines finish_reason based on accumulated state.
                    if "choices" in chunk_dict and chunk_dict["choices"]:
                        choice = chunk_dict["choices"][0]
                        delta = choice.get("delta", {})
                        usage = chunk_dict.get("usage", {})

                        # Track tool_calls across ALL chunks - if we ever see one,
                        # finish_reason must be tool_calls
                        if delta.get("tool_calls"):
                            has_tool_calls = True
                            accumulated_finish_reason = "tool_calls"

                        # Detect final chunk: has usage with completion_tokens > 0
                        has_completion_tokens = (
                            usage
                            and isinstance(usage, dict)
                            and usage.get("completion_tokens", 0) > 0
                        )

                        if has_completion_tokens:
                            # FINAL CHUNK: Determine correct finish_reason
                            if has_tool_calls:
                                # Tool calls always win
                                choice["finish_reason"] = "tool_calls"
                            elif accumulated_finish_reason:
                                # Use accumulated reason (length, content_filter, etc.)
                                choice["finish_reason"] = accumulated_finish_reason
                            else:
                                # Default to stop
                                choice["finish_reason"] = "stop"
                        else:
                            # INTERMEDIATE CHUNK: Never emit finish_reason
                            # (litellm.ModelResponse defaults to "stop" which is wrong)
                            choice["finish_reason"] = None

                    yield f"data: {json.dumps(chunk_dict)}\n\n"

                    if hasattr(chunk, "usage") and chunk.usage:
                        last_usage = chunk.usage

                except StopAsyncIteration:
                    stream_completed = True
                    if json_buffer:
                        lib_logger.info(
                            f"Stream ended with incomplete data in buffer: {json_buffer}"
                        )
                    if last_usage:
                        # Create a dummy ModelResponse for recording (only usage matters)
                        dummy_response = litellm.ModelResponse(usage=last_usage)
                        await self.usage_manager.record_success(
                            key, model, dummy_response
                        )
                    else:
                        # If no usage seen (rare), record success without tokens/cost
                        await self.usage_manager.record_success(key, model)
                    break
                except CredentialNeedsReauthError as e:
                    # This credential needs re-authentication but re-auth is already queued.
                    # Wrap it so the outer retry loop can rotate to the next credential.
                    # No scary traceback needed - this is an expected recovery scenario.
                    raise StreamedAPIError("Credential needs re-authentication", data=e)
                except (
                    litellm.RateLimitError,
                    litellm.ServiceUnavailableError,
                    litellm.InternalServerError,
                    APIConnectionError,
                    httpx.HTTPStatusError,
                ) as e:
                    # This is a critical, typed error from litellm or httpx that signals a key failure.
                    # We do not try to parse it here. We wrap it and raise it immediately
                    # for the outer retry loop to handle.
                    lib_logger.warning(
                        f"Caught a critical API error mid-stream: {type(e).__name__}. Signaling for credential rotation."
                    )
                    raise StreamedAPIError("Provider error received in stream", data=e)
                except Exception as e:
                    try:
                        raw_chunk = ""
                        # Google streams errors inside a bytes representation (b'{...}').
                        # We use regex to extract the content, which is more reliable than splitting.
                        match = re.search(r"b'(\{.*\})'", str(e), re.DOTALL)
                        if match:
                            # The extracted string is unicode-escaped (e.g., '\\n'). We must decode it.
                            raw_chunk = codecs.decode(match.group(1), "unicode_escape")
                        else:
                            # Fallback for other potential error formats that use "Received chunk:".
                            chunk_from_split = (
                                str(e).split("Received chunk:")[-1].strip()
                            )
                            if chunk_from_split != str(
                                e
                            ):  # Ensure the split actually did something
                                raw_chunk = chunk_from_split

                        if not raw_chunk:
                            # If we could not extract a valid chunk, we cannot proceed with reassembly.
                            # This indicates a different, unexpected error type. Re-raise it.
                            raise e

                        # Append the clean chunk to the buffer and try to parse.
                        json_buffer += raw_chunk
                        parsed_data = json.loads(json_buffer)

                        # If parsing succeeds, we have the complete object.
                        lib_logger.info(
                            f"Successfully reassembled JSON from stream: {json_buffer}"
                        )
                        # Wrap the complete error object and raise it.
                        # The outer function will decide how to handle it.
                        raise StreamedAPIError(
                            "Provider error received in stream", data=parsed_data
                        )
                    except json.JSONDecodeError:
                        # This is the expected outcome if the JSON in the buffer is not yet complete.
                        lib_logger.info(
                            f"Buffer still incomplete. Waiting for more chunks: {json_buffer}"
                        )
                        continue  # Continue to the next loop to get the next chunk.
                    except StreamedAPIError:
                        # Re-raise to be caught by the outer retry handler.
                        raise
                    except Exception as buffer_exc:
                        # If the error was not a JSONDecodeError, it's an unexpected internal error.
                        lib_logger.error(
                            f"Error during stream buffering logic: {buffer_exc}. Discarding buffer."
                        )
                        json_buffer = (
                            ""  # Clear the corrupted buffer to prevent further issues.
                        )
                        raise buffer_exc
        except StreamedAPIError:
            # This is caught by the acompletion retry logic.
            # We re-raise it to ensure it's not caught by the generic 'except Exception'.
            raise
        except Exception as e:
            # Catch any other unexpected errors during streaming.
            lib_logger.error(f"Caught unexpected exception of type: {type(e).__name__}")
            lib_logger.error(
                f"An unexpected error occurred during the stream for credential {mask_credential(key)}: {e}"
            )
            # We still need to raise it so the client knows something went wrong.
            raise
        finally:
            # This block runs regardless of how the stream terminates (completion, client disconnect, etc.).
            # The primary goal is to ensure usage is always logged internally.
            await self.usage_manager.release_key(key, model)
            lib_logger.info(
                f"STREAM FINISHED and lock released for credential {mask_credential(key)}."
            )
            # Only send [DONE] if the stream completed naturally and the client is still there.
            # This prevents sending [DONE] to a disconnected client or after an error.
            if stream_completed and (
                not request or not await request.is_disconnected()
            ):
                yield "data: [DONE]\n\n"
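
    # Illustrative finish_reason flow through the wrapper above (values are
    # examples): intermediate chunks are forced to finish_reason=None; a chunk
    # whose delta carries tool_calls latches accumulated_finish_reason to
    # "tool_calls"; the final chunk (usage.completion_tokens > 0) then emits
    # "tool_calls" even if the provider itself reported "stop".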

    async def _execute_with_retry(
        self,
        api_call: callable,
        request: Optional[Any],
        pre_request_callback: Optional[callable] = None,
        **kwargs,
    ) -> Any:
        """A generic retry mechanism for non-streaming API calls."""
        model = kwargs.get("model")
        if not model:
            raise ValueError("'model' is a required parameter.")

        provider = model.split("/")[0]
        if provider not in self.all_credentials:
            raise ValueError(
                f"No API keys or OAuth credentials configured for provider: {provider}"
            )

        # Establish a global deadline for the entire request lifecycle.
        deadline = time.time() + self.global_timeout

        # Create a mutable copy of the keys and shuffle it to ensure
        # that the key selection is randomized, which is crucial when
        # multiple keys have the same usage stats.
        credentials_for_provider = list(self.all_credentials[provider])
        random.shuffle(credentials_for_provider)

        # Filter out credentials that are unavailable (queued for re-auth)
        provider_plugin = self._get_provider_instance(provider)
        if provider_plugin and hasattr(provider_plugin, "is_credential_available"):
            available_creds = [
                cred
                for cred in credentials_for_provider
                if provider_plugin.is_credential_available(cred)
            ]
            if available_creds:
                credentials_for_provider = available_creds
            # If all credentials are unavailable, keep the original list
            # (better to try unavailable creds than fail immediately)

        tried_creds = set()
        last_exception = None
        kwargs = self._convert_model_params(**kwargs)

        # Resolve model ID early, before any credential operations.
        # This ensures consistent model ID usage for acquisition, release, and tracking.
        resolved_model = self._resolve_model_id(model, provider)
        if resolved_model != model:
            lib_logger.info(f"Resolved model '{model}' to '{resolved_model}'")
            model = resolved_model
            kwargs["model"] = model  # Ensure kwargs has the resolved model for litellm

        # Filter by model tier requirement and build priority map
        credential_priorities = None
        if provider_plugin and hasattr(provider_plugin, "get_model_tier_requirement"):
            required_tier = provider_plugin.get_model_tier_requirement(model)
            if required_tier is not None:
                # Filter OUT only credentials we KNOW are too low priority.
                # Keep credentials with unknown priority (None) - they might be high priority.
                incompatible_creds = []
                compatible_creds = []
                unknown_creds = []

                for cred in credentials_for_provider:
                    if hasattr(provider_plugin, "get_credential_priority"):
                        priority = provider_plugin.get_credential_priority(cred)
                        if priority is None:
                            # Unknown priority - keep it, will be discovered on first use
                            unknown_creds.append(cred)
                        elif priority <= required_tier:
                            # Known compatible priority
                            compatible_creds.append(cred)
                        else:
                            # Known incompatible priority (too low)
                            incompatible_creds.append(cred)
                    else:
                        # Provider doesn't support priorities - keep all
                        unknown_creds.append(cred)

                # If we have any known-compatible or unknown credentials, use them
                tier_compatible_creds = compatible_creds + unknown_creds
                if tier_compatible_creds:
                    credentials_for_provider = tier_compatible_creds
                    if compatible_creds and unknown_creds:
                        lib_logger.info(
                            f"Model {model} requires priority <= {required_tier}. "
                            f"Using {len(compatible_creds)} known-compatible + {len(unknown_creds)} unknown-tier credentials."
                        )
                    elif compatible_creds:
                        lib_logger.info(
                            f"Model {model} requires priority <= {required_tier}. "
                            f"Using {len(compatible_creds)} known-compatible credentials."
                        )
                    else:
                        lib_logger.info(
                            f"Model {model} requires priority <= {required_tier}. "
                            f"Using {len(unknown_creds)} unknown-tier credentials (will discover on use)."
                        )
                elif incompatible_creds:
                    # Only known-incompatible credentials remain
                    lib_logger.warning(
                        f"Model {model} requires priority <= {required_tier} credentials, "
                        f"but all {len(incompatible_creds)} known credentials have priority > {required_tier}. "
                        f"Request will likely fail."
                    )
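
        # Example (illustrative): with required_tier == 2 and credential
        # priorities {credA: 1, credB: 3, credC: None}, the filter above keeps
        # credA (known-compatible) and credC (unknown, discovered on first use)
        # and drops credB (known priority 3 > 2).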

        # Build priority map and tier names map for usage_manager
        credential_tier_names = None
        if provider_plugin and hasattr(provider_plugin, "get_credential_priority"):
            credential_priorities = {}
            credential_tier_names = {}
            for cred in credentials_for_provider:
                priority = provider_plugin.get_credential_priority(cred)
                if priority is not None:
                    credential_priorities[cred] = priority
                    # Also get tier name for logging
                    if hasattr(provider_plugin, "get_credential_tier_name"):
                        tier_name = provider_plugin.get_credential_tier_name(cred)
                        if tier_name:
                            credential_tier_names[cred] = tier_name

            if credential_priorities:
                lib_logger.debug(
                    f"Credential priorities for {provider}: {', '.join(f'P{p}={len([c for c in credentials_for_provider if credential_priorities.get(c) == p])}' for p in sorted(set(credential_priorities.values())))}"
                )

        # Initialize error accumulator for tracking errors across credential rotation
        error_accumulator = RequestErrorAccumulator()
        error_accumulator.model = model
        error_accumulator.provider = provider

        # The main rotation loop. It continues as long as there are untried
        # credentials and the global deadline has not been exceeded.
        while (
            len(tried_creds) < len(credentials_for_provider)
            and time.time() < deadline
        ):
            current_cred = None
            key_acquired = False
            try:
                # Check for a provider-wide cooldown first.
                if await self.cooldown_manager.is_cooling_down(provider):
                    remaining_cooldown = (
                        await self.cooldown_manager.get_cooldown_remaining(provider)
                    )
                    remaining_budget = deadline - time.time()
                    # If the cooldown is longer than the remaining time budget, fail fast.
                    if remaining_cooldown > remaining_budget:
                        lib_logger.warning(
                            f"Provider {provider} cooldown ({remaining_cooldown:.2f}s) exceeds remaining request budget ({remaining_budget:.2f}s). Failing early."
                        )
                        break
                    lib_logger.warning(
                        f"Provider {provider} is in cooldown. Waiting for {remaining_cooldown:.2f} seconds."
                    )
                    await asyncio.sleep(remaining_cooldown)

                creds_to_try = [
                    c for c in credentials_for_provider if c not in tried_creds
                ]
                if not creds_to_try:
                    break

                # Get count of credentials not on cooldown for this model
                available_creds = (
                    await self.usage_manager.get_available_credentials_for_model(
                        creds_to_try, model
                    )
                )
                available_count = len(available_creds)
                total_count = len(credentials_for_provider)

                lib_logger.info(
                    f"Acquiring key for model {model}. Tried keys: {len(tried_creds)}/{available_count}({total_count})"
                )
                max_concurrent = self.max_concurrent_requests_per_key.get(provider, 1)
                current_cred = await self.usage_manager.acquire_key(
                    available_keys=creds_to_try,
                    model=model,
                    deadline=deadline,
                    max_concurrent=max_concurrent,
                    credential_priorities=credential_priorities,
                    credential_tier_names=credential_tier_names,
                )
                key_acquired = True
                tried_creds.add(current_cred)

                litellm_kwargs = self.all_providers.get_provider_kwargs(**kwargs.copy())

                # Merge provider-specific params
                if provider in self.litellm_provider_params:
                    litellm_kwargs["litellm_params"] = {
                        **self.litellm_provider_params[provider],
                        **litellm_kwargs.get("litellm_params", {}),
                    }

                provider_plugin = self._get_provider_instance(provider)
                # Model ID is already resolved before the loop, and kwargs['model'] is updated.
                # No further resolution needed here.

                # Apply model-specific options for custom providers
                if provider_plugin and hasattr(provider_plugin, "get_model_options"):
                    model_options = provider_plugin.get_model_options(model)
                    if model_options:
                        # Merge model options into litellm_kwargs
                        for key, value in model_options.items():
                            if key == "reasoning_effort":
                                litellm_kwargs["reasoning_effort"] = value
                            elif key not in litellm_kwargs:
                                litellm_kwargs[key] = value

                if provider_plugin and provider_plugin.has_custom_logic():
                    lib_logger.debug(
                        f"Provider '{provider}' has custom logic. Delegating call."
                    )
                    litellm_kwargs["credential_identifier"] = current_cred
                    litellm_kwargs["enable_request_logging"] = (
                        self.enable_request_logging
                    )

                    # Check body first for custom_reasoning_budget
                    if "custom_reasoning_budget" in kwargs:
                        litellm_kwargs["custom_reasoning_budget"] = kwargs[
                            "custom_reasoning_budget"
                        ]
                    else:
                        custom_budget_header = None
                        if request and hasattr(request, "headers"):
                            custom_budget_header = request.headers.get(
                                "custom_reasoning_budget"
                            )
                        if custom_budget_header is not None:
                            is_budget_enabled = custom_budget_header.lower() == "true"
                            litellm_kwargs["custom_reasoning_budget"] = (
                                is_budget_enabled
                            )

                    # Retry loop for custom providers - mirrors streaming path error handling
                    for attempt in range(self.max_retries):
                        try:
                            lib_logger.info(
                                f"Attempting call with credential {mask_credential(current_cred)} (Attempt {attempt + 1}/{self.max_retries})"
                            )
                            if pre_request_callback:
                                try:
                                    await pre_request_callback(request, litellm_kwargs)
                                except Exception as e:
                                    if self.abort_on_callback_error:
                                        raise PreRequestCallbackError(
                                            f"Pre-request callback failed: {e}"
                                        ) from e
                                    else:
                                        lib_logger.warning(
                                            f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}"
                                        )

                            response = await provider_plugin.acompletion(
                                self.http_client, **litellm_kwargs
                            )

                            # For non-streaming, success is immediate
                            await self.usage_manager.record_success(
                                current_cred, model, response
                            )
                            await self.usage_manager.release_key(current_cred, model)
                            key_acquired = False
                            return response

                        except (
                            litellm.RateLimitError,
                            httpx.HTTPStatusError,
                        ) as e:
                            last_exception = e
                            classified_error = classify_error(e, provider=provider)
                            error_message = str(e).split("\n")[0]
                            log_failure(
                                api_key=current_cred,
                                model=model,
                                attempt=attempt + 1,
                                error=e,
                                request_headers=dict(request.headers) if request else {},
                            )
                            # Record in accumulator for client reporting
                            error_accumulator.record_error(
                                current_cred, classified_error, error_message
                            )

                            # Check if this error should trigger rotation
                            if not should_rotate_on_error(classified_error):
                                lib_logger.error(
                                    f"Non-recoverable error ({classified_error.error_type}) during custom provider call. Failing."
                                )
                                raise last_exception

                            # Handle rate limits with cooldown (exclude quota_exceeded)
                            if classified_error.error_type == "rate_limit":
                                cooldown_duration = classified_error.retry_after or 60
                                await self.cooldown_manager.start_cooldown(
                                    provider, cooldown_duration
                                )

                            await self.usage_manager.record_failure(
                                current_cred, model, classified_error
                            )
                            lib_logger.warning(
                                f"Cred {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code}). Rotating."
                            )
                            break  # Rotate to next credential

                        except (
                            APIConnectionError,
                            litellm.InternalServerError,
                            litellm.ServiceUnavailableError,
                        ) as e:
                            last_exception = e
                            log_failure(
                                api_key=current_cred,
                                model=model,
                                attempt=attempt + 1,
                                error=e,
                                request_headers=dict(request.headers) if request else {},
                            )
                            classified_error = classify_error(e, provider=provider)
                            error_message = str(e).split("\n")[0]
                            # Provider-level error: don't increment consecutive failures
                            await self.usage_manager.record_failure(
                                current_cred,
                                model,
                                classified_error,
                                increment_consecutive_failures=False,
                            )
                            if attempt >= self.max_retries - 1:
                                error_accumulator.record_error(
                                    current_cred, classified_error, error_message
                                )
                                lib_logger.warning(
                                    f"Cred {mask_credential(current_cred)} failed after max retries. Rotating."
                                )
                                break
                            wait_time = classified_error.retry_after or (
                                2**attempt
                            ) + random.uniform(0, 1)
                            remaining_budget = deadline - time.time()
                            if wait_time > remaining_budget:
                                error_accumulator.record_error(
                                    current_cred, classified_error, error_message
                                )
                                lib_logger.warning(
                                    f"Retry wait ({wait_time:.2f}s) exceeds budget. Rotating."
                                )
                                break
                            lib_logger.warning(
                                f"Cred {mask_credential(current_cred)} server error. Retrying in {wait_time:.2f}s."
                            )
                            await asyncio.sleep(wait_time)
                            continue

                        except Exception as e:
                            last_exception = e
                            log_failure(
                                api_key=current_cred,
                                model=model,
                                attempt=attempt + 1,
                                error=e,
                                request_headers=dict(request.headers) if request else {},
                            )
                            classified_error = classify_error(e, provider=provider)
                            error_message = str(e).split("\n")[0]
                            # Record in accumulator
                            error_accumulator.record_error(
                                current_cred, classified_error, error_message
                            )
                            lib_logger.warning(
                                f"Cred {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code})."
                            )
                            # Check if this error should trigger rotation
                            if not should_rotate_on_error(classified_error):
                                lib_logger.error(
                                    f"Non-recoverable error ({classified_error.error_type}). Failing."
                                )
                                raise last_exception
                            # Handle rate limits with cooldown (exclude quota_exceeded)
                            if (
                                classified_error.status_code == 429
                                and classified_error.error_type != "quota_exceeded"
                            ) or classified_error.error_type == "rate_limit":
                                cooldown_duration = classified_error.retry_after or 60
                                await self.cooldown_manager.start_cooldown(
                                    provider, cooldown_duration
                                )
                            await self.usage_manager.record_failure(
                                current_cred, model, classified_error
                            )
                            break  # Rotate to next credential

                    # If the inner loop breaks, it means the key failed and we need to rotate.
                    # Continue to the next iteration of the outer while loop to pick a new key.
                    continue
                else:
                    # This is the standard API Key / litellm-handled provider logic
                    is_oauth = provider in self.oauth_providers
                    if is_oauth:
                        # Standard OAuth provider (not custom)
                        # ... (logic to set headers) ...
                        pass
                    else:  # API Key
                        litellm_kwargs["api_key"] = current_cred

                    provider_instance = self._get_provider_instance(provider)
                    if provider_instance:
                        # Ensure default Gemini safety settings are present (without overriding request)
                        try:
                            self._apply_default_safety_settings(
                                litellm_kwargs, provider
                            )
                        except Exception:
                            # If anything goes wrong here, avoid breaking the request flow.
                            lib_logger.debug(
                                "Could not apply default safety settings; continuing."
                            )
) if "safety_settings" in litellm_kwargs: converted_settings = ( provider_instance.convert_safety_settings( litellm_kwargs["safety_settings"] ) ) if converted_settings is not None: litellm_kwargs["safety_settings"] = converted_settings else: del litellm_kwargs["safety_settings"] if provider == "gemini" and provider_instance: provider_instance.handle_thinking_parameter( litellm_kwargs, model ) if provider == "nvidia_nim" and provider_instance: provider_instance.handle_thinking_parameter( litellm_kwargs, model ) if "gemma-3" in model and "messages" in litellm_kwargs: litellm_kwargs["messages"] = [ {"role": "user", "content": m["content"]} if m.get("role") == "system" else m for m in litellm_kwargs["messages"] ] litellm_kwargs = sanitize_request_payload(litellm_kwargs, model) for attempt in range(self.max_retries): try: lib_logger.info( f"Attempting call with credential {mask_credential(current_cred)} (Attempt {attempt + 1}/{self.max_retries})" ) if pre_request_callback: try: await pre_request_callback(request, litellm_kwargs) except Exception as e: if self.abort_on_callback_error: raise PreRequestCallbackError( f"Pre-request callback failed: {e}" ) from e else: lib_logger.warning( f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}" ) # Convert model parameters for custom providers right before LiteLLM call final_kwargs = self._convert_model_params_for_litellm( **litellm_kwargs ) response = await api_call( **final_kwargs, logger_fn=self._litellm_logger_callback, ) await self.usage_manager.record_success( current_cred, model, response ) await self.usage_manager.release_key(current_cred, model) key_acquired = False return response except litellm.RateLimitError as e: last_exception = e log_failure( api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {}, ) classified_error = classify_error(e, provider=provider) # Extract a clean error message for the user-facing log error_message = str(e).split("\n")[0] # Record in accumulator for client reporting error_accumulator.record_error( current_cred, classified_error, error_message ) lib_logger.info( f"Key {mask_credential(current_cred)} hit rate limit for {model}. Rotating key." ) # Only trigger provider-wide cooldown for rate limits, not quota issues if ( classified_error.status_code == 429 and classified_error.error_type != "quota_exceeded" ): cooldown_duration = classified_error.retry_after or 60 await self.cooldown_manager.start_cooldown( provider, cooldown_duration ) await self.usage_manager.record_failure( current_cred, model, classified_error ) break # Move to the next key except ( APIConnectionError, litellm.InternalServerError, litellm.ServiceUnavailableError, ) as e: last_exception = e log_failure( api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {}, ) classified_error = classify_error(e, provider=provider) error_message = str(e).split("\n")[0] # Provider-level error: don't increment consecutive failures await self.usage_manager.record_failure( current_cred, model, classified_error, increment_consecutive_failures=False, ) if attempt >= self.max_retries - 1: # Record in accumulator only on final failure for this key error_accumulator.record_error( current_cred, classified_error, error_message ) lib_logger.warning( f"Key {mask_credential(current_cred)} failed after max retries due to server error. Rotating." 
                                break  # Move to the next key

                            # For temporary errors, wait before retrying with the same key.
                            wait_time = classified_error.retry_after or (
                                2**attempt
                            ) + random.uniform(0, 1)
                            remaining_budget = deadline - time.time()
                            # If the required wait time exceeds the budget, don't wait;
                            # rotate to the next key immediately.
                            if wait_time > remaining_budget:
                                error_accumulator.record_error(
                                    current_cred, classified_error, error_message
                                )
                                lib_logger.warning(
                                    f"Retry wait ({wait_time:.2f}s) exceeds budget ({remaining_budget:.2f}s). Rotating key."
                                )
                                break
                            lib_logger.warning(
                                f"Key {mask_credential(current_cred)} server error. Retrying in {wait_time:.2f}s."
                            )
                            await asyncio.sleep(wait_time)
                            continue  # Retry with the same key

                        except httpx.HTTPStatusError as e:
                            # Handle HTTP errors from httpx (e.g., from custom providers like Antigravity)
                            last_exception = e
                            log_failure(
                                api_key=current_cred,
                                model=model,
                                attempt=attempt + 1,
                                error=e,
                                request_headers=dict(request.headers) if request else {},
                            )
                            classified_error = classify_error(e, provider=provider)
                            error_message = str(e).split("\n")[0]
                            lib_logger.warning(
                                f"Key {mask_credential(current_cred)} HTTP {e.response.status_code} ({classified_error.error_type})."
                            )
                            # Check if this error should trigger rotation
                            if not should_rotate_on_error(classified_error):
                                lib_logger.error(
                                    f"Non-recoverable error ({classified_error.error_type}). Failing request."
                                )
                                raise last_exception
                            # Record in accumulator after confirming it's a rotatable error
                            error_accumulator.record_error(
                                current_cred, classified_error, error_message
                            )
                            # Handle rate limits with cooldown (exclude quota_exceeded from provider-wide cooldown)
                            if classified_error.error_type == "rate_limit":
                                cooldown_duration = classified_error.retry_after or 60
                                await self.cooldown_manager.start_cooldown(
                                    provider, cooldown_duration
                                )
                            # Check if we should retry same key (server errors with retries left)
                            if (
                                should_retry_same_key(classified_error)
                                and attempt < self.max_retries - 1
                            ):
                                wait_time = classified_error.retry_after or (
                                    2**attempt
                                ) + random.uniform(0, 1)
                                remaining_budget = deadline - time.time()
                                if wait_time <= remaining_budget:
                                    lib_logger.warning(
                                        f"Server error, retrying same key in {wait_time:.2f}s."
                                    )
                                    await asyncio.sleep(wait_time)
                                    continue
                            # Record failure and rotate to next key
                            await self.usage_manager.record_failure(
                                current_cred, model, classified_error
                            )
                            lib_logger.info(
                                f"Rotating to next key after {classified_error.error_type} error."
                            )
                            break

                        except Exception as e:
                            last_exception = e
                            log_failure(
                                api_key=current_cred,
                                model=model,
                                attempt=attempt + 1,
                                error=e,
                                request_headers=dict(request.headers) if request else {},
                            )
                            if request and await request.is_disconnected():
                                lib_logger.warning(
                                    f"Client disconnected. Aborting retries for {mask_credential(current_cred)}."
                                )
                                raise last_exception
                            classified_error = classify_error(e, provider=provider)
                            error_message = str(e).split("\n")[0]
                            lib_logger.warning(
                                f"Key {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code})."
                            )
                            # Handle rate limits with cooldown (exclude quota_exceeded from provider-wide cooldown)
                            if (
                                classified_error.status_code == 429
                                and classified_error.error_type != "quota_exceeded"
                            ) or classified_error.error_type == "rate_limit":
                                cooldown_duration = classified_error.retry_after or 60
                                await self.cooldown_manager.start_cooldown(
                                    provider, cooldown_duration
                                )
                            # Check if this error should trigger rotation
                            if not should_rotate_on_error(classified_error):
                                lib_logger.error(
                                    f"Non-recoverable error ({classified_error.error_type}). Failing request."
                                )
                                raise last_exception
                            # Record in accumulator after confirming it's a rotatable error
                            error_accumulator.record_error(
                                current_cred, classified_error, error_message
                            )
                            await self.usage_manager.record_failure(
                                current_cred, model, classified_error
                            )
                            break  # Try next key for other errors
            finally:
                if key_acquired and current_cred:
                    await self.usage_manager.release_key(current_cred, model)

        # Check if we exhausted all credentials or timed out
        if time.time() >= deadline:
            error_accumulator.timeout_occurred = True

        if error_accumulator.has_errors():
            # Log concise summary for server logs
            lib_logger.error(error_accumulator.build_log_message())
            # Return the structured error response for the client
            return error_accumulator.build_client_error_response()

        # Return None to indicate failure without error details (shouldn't normally happen)
        lib_logger.warning(
            "Unexpected state: request failed with no recorded errors. "
            "This may indicate a logic error in error tracking."
        )
        return None
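
    # The streaming variant below follows the same two-level retry structure as
    # _execute_with_retry: an outer while-loop that rotates across credentials
    # until the deadline, and an inner for-loop that retries the same credential
    # only for transient server errors.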

    async def _streaming_acompletion_with_retry(
        self,
        request: Optional[Any],
        pre_request_callback: Optional[callable] = None,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """A dedicated generator for retrying streaming completions with full
        request preparation and per-key retries."""
        model = kwargs.get("model")
        provider = model.split("/")[0]

        # Create a mutable copy of the keys and shuffle it.
        credentials_for_provider = list(self.all_credentials[provider])
        random.shuffle(credentials_for_provider)

        # Filter out credentials that are unavailable (queued for re-auth)
        provider_plugin = self._get_provider_instance(provider)
        if provider_plugin and hasattr(provider_plugin, "is_credential_available"):
            available_creds = [
                cred
                for cred in credentials_for_provider
                if provider_plugin.is_credential_available(cred)
            ]
            if available_creds:
                credentials_for_provider = available_creds
            # If all credentials are unavailable, keep the original list
            # (better to try unavailable creds than fail immediately)

        deadline = time.time() + self.global_timeout
        tried_creds = set()
        last_exception = None
        kwargs = self._convert_model_params(**kwargs)
        consecutive_quota_failures = 0

        # Resolve model ID early, before any credential operations.
        # This ensures consistent model ID usage for acquisition, release, and tracking.
        resolved_model = self._resolve_model_id(model, provider)
        if resolved_model != model:
            lib_logger.info(f"Resolved model '{model}' to '{resolved_model}'")
            model = resolved_model
            kwargs["model"] = model  # Ensure kwargs has the resolved model for litellm

        # Filter by model tier requirement and build priority map
        credential_priorities = None
        if provider_plugin and hasattr(provider_plugin, "get_model_tier_requirement"):
            required_tier = provider_plugin.get_model_tier_requirement(model)
            if required_tier is not None:
                # Filter OUT only credentials we KNOW are too low priority.
                # Keep credentials with unknown priority (None) - they might be high priority.
                incompatible_creds = []
                compatible_creds = []
                unknown_creds = []

                for cred in credentials_for_provider:
                    if hasattr(provider_plugin, "get_credential_priority"):
                        priority = provider_plugin.get_credential_priority(cred)
                        if priority is None:
                            # Unknown priority - keep it, will be discovered on first use
                            unknown_creds.append(cred)
                        elif priority <= required_tier:
                            # Known compatible priority
                            compatible_creds.append(cred)
                        else:
                            # Known incompatible priority (too low)
                            incompatible_creds.append(cred)
                    else:
                        # Provider doesn't support priorities - keep all
                        unknown_creds.append(cred)

                # If we have any known-compatible or unknown credentials, use them
                tier_compatible_creds = compatible_creds + unknown_creds
                if tier_compatible_creds:
                    credentials_for_provider = tier_compatible_creds
                    if compatible_creds and unknown_creds:
                        lib_logger.info(
                            f"Model {model} requires priority <= {required_tier}. "
                            f"Using {len(compatible_creds)} known-compatible + {len(unknown_creds)} unknown-tier credentials."
                        )
                    elif compatible_creds:
                        lib_logger.info(
                            f"Model {model} requires priority <= {required_tier}. "
                            f"Using {len(compatible_creds)} known-compatible credentials."
                        )
                    else:
                        lib_logger.info(
                            f"Model {model} requires priority <= {required_tier}. "
                            f"Using {len(unknown_creds)} unknown-tier credentials (will discover on use)."
                        )
                elif incompatible_creds:
                    # Only known-incompatible credentials remain
                    lib_logger.warning(
                        f"Model {model} requires priority <= {required_tier} credentials, "
                        f"but all {len(incompatible_creds)} known credentials have priority > {required_tier}. "
                        f"Request will likely fail."
                    )

        # Build priority map and tier names map for usage_manager
        credential_tier_names = None
        if provider_plugin and hasattr(provider_plugin, "get_credential_priority"):
            credential_priorities = {}
            credential_tier_names = {}
            for cred in credentials_for_provider:
                priority = provider_plugin.get_credential_priority(cred)
                if priority is not None:
                    credential_priorities[cred] = priority
                    # Also get tier name for logging
                    if hasattr(provider_plugin, "get_credential_tier_name"):
                        tier_name = provider_plugin.get_credential_tier_name(cred)
                        if tier_name:
                            credential_tier_names[cred] = tier_name

            if credential_priorities:
                lib_logger.debug(
                    f"Credential priorities for {provider}: {', '.join(f'P{p}={len([c for c in credentials_for_provider if credential_priorities.get(c) == p])}' for p in sorted(set(credential_priorities.values())))}"
                )

        # Initialize error accumulator for tracking errors across credential rotation
        error_accumulator = RequestErrorAccumulator()
        error_accumulator.model = model
        error_accumulator.provider = provider

        try:
            while (
                len(tried_creds) < len(credentials_for_provider)
                and time.time() < deadline
            ):
                current_cred = None
                key_acquired = False
                try:
                    if await self.cooldown_manager.is_cooling_down(provider):
                        remaining_cooldown = (
                            await self.cooldown_manager.get_cooldown_remaining(provider)
                        )
                        remaining_budget = deadline - time.time()
                        if remaining_cooldown > remaining_budget:
                            lib_logger.warning(
                                f"Provider {provider} cooldown ({remaining_cooldown:.2f}s) exceeds remaining request budget ({remaining_budget:.2f}s). Failing early."
                            )
                            break
                        lib_logger.warning(
                            f"Provider {provider} is in a global cooldown. All requests to this provider will be paused for {remaining_cooldown:.2f} seconds."
                        )
                        await asyncio.sleep(remaining_cooldown)

                    creds_to_try = [
                        c for c in credentials_for_provider if c not in tried_creds
                    ]
                    if not creds_to_try:
                        lib_logger.warning(
                            f"All credentials for provider {provider} have been tried. No more credentials to rotate to."
                        )
                        break

                    # Get count of credentials not on cooldown for this model
                    available_creds = (
                        await self.usage_manager.get_available_credentials_for_model(
                            creds_to_try, model
                        )
                    )
                    available_count = len(available_creds)
                    total_count = len(credentials_for_provider)

                    lib_logger.info(
                        f"Acquiring credential for model {model}. Tried credentials: {len(tried_creds)}/{available_count}({total_count})"
                    )
                    max_concurrent = self.max_concurrent_requests_per_key.get(
                        provider, 1
                    )
                    current_cred = await self.usage_manager.acquire_key(
                        available_keys=creds_to_try,
                        model=model,
                        deadline=deadline,
                        max_concurrent=max_concurrent,
                        credential_priorities=credential_priorities,
                        credential_tier_names=credential_tier_names,
                    )
                    key_acquired = True
                    tried_creds.add(current_cred)

                    litellm_kwargs = self.all_providers.get_provider_kwargs(
                        **kwargs.copy()
                    )
                    if "reasoning_effort" in kwargs:
                        litellm_kwargs["reasoning_effort"] = kwargs["reasoning_effort"]

                    # Check body first for custom_reasoning_budget
                    if "custom_reasoning_budget" in kwargs:
                        litellm_kwargs["custom_reasoning_budget"] = kwargs[
                            "custom_reasoning_budget"
                        ]
                    else:
                        custom_budget_header = None
                        if request and hasattr(request, "headers"):
                            custom_budget_header = request.headers.get(
                                "custom_reasoning_budget"
                            )
                        if custom_budget_header is not None:
                            is_budget_enabled = custom_budget_header.lower() == "true"
                            litellm_kwargs["custom_reasoning_budget"] = (
                                is_budget_enabled
                            )

                    # Merge provider-specific params
                    if provider in self.litellm_provider_params:
                        litellm_kwargs["litellm_params"] = {
                            **self.litellm_provider_params[provider],
                            **litellm_kwargs.get("litellm_params", {}),
                        }

                    provider_plugin = self._get_provider_instance(provider)
                    # Model ID is already resolved before the loop, and kwargs['model'] is updated.
                    # No further resolution needed here.

                    # Apply model-specific options for custom providers
                    if provider_plugin and hasattr(
                        provider_plugin, "get_model_options"
                    ):
                        model_options = provider_plugin.get_model_options(model)
                        if model_options:
                            # Merge model options into litellm_kwargs
                            for key, value in model_options.items():
                                if key == "reasoning_effort":
                                    litellm_kwargs["reasoning_effort"] = value
                                elif key not in litellm_kwargs:
                                    litellm_kwargs[key] = value

                    if provider_plugin and provider_plugin.has_custom_logic():
                        lib_logger.debug(
                            f"Provider '{provider}' has custom logic. Delegating call."
                        )
                        litellm_kwargs["credential_identifier"] = current_cred
                        litellm_kwargs["enable_request_logging"] = (
                            self.enable_request_logging
                        )

                        for attempt in range(self.max_retries):
                            try:
                                lib_logger.info(
                                    f"Attempting stream with credential {mask_credential(current_cred)} (Attempt {attempt + 1}/{self.max_retries})"
                                )
                                if pre_request_callback:
                                    try:
                                        await pre_request_callback(
                                            request, litellm_kwargs
                                        )
                                    except Exception as e:
                                        if self.abort_on_callback_error:
                                            raise PreRequestCallbackError(
                                                f"Pre-request callback failed: {e}"
                                            ) from e
                                        else:
                                            lib_logger.warning(
                                                f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}"
                                            )

                                response = await provider_plugin.acompletion(
                                    self.http_client, **litellm_kwargs
                                )
                                lib_logger.info(
                                    f"Stream connection established for credential {mask_credential(current_cred)}. Processing response."
                                )
) key_acquired = False stream_generator = self._safe_streaming_wrapper( response, current_cred, model, request, provider_plugin, ) async for chunk in stream_generator: yield chunk return except ( StreamedAPIError, litellm.RateLimitError, httpx.HTTPStatusError, ) as e: last_exception = e # If the exception is our custom wrapper, unwrap the original error original_exc = getattr(e, "data", e) classified_error = classify_error( original_exc, provider=provider ) error_message = str(original_exc).split("\n")[0] log_failure( api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {}, ) # Record in accumulator for client reporting error_accumulator.record_error( current_cred, classified_error, error_message ) # Check if this error should trigger rotation if not should_rotate_on_error(classified_error): lib_logger.error( f"Non-recoverable error ({classified_error.error_type}) during custom stream. Failing." ) raise last_exception # Handle rate limits with cooldown (exclude quota_exceeded) if classified_error.error_type == "rate_limit": cooldown_duration = ( classified_error.retry_after or 60 ) await self.cooldown_manager.start_cooldown( provider, cooldown_duration ) await self.usage_manager.record_failure( current_cred, model, classified_error ) lib_logger.warning( f"Cred {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code}). Rotating." ) break except ( APIConnectionError, litellm.InternalServerError, litellm.ServiceUnavailableError, ) as e: last_exception = e log_failure( api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {}, ) classified_error = classify_error(e, provider=provider) error_message = str(e).split("\n")[0] # Provider-level error: don't increment consecutive failures await self.usage_manager.record_failure( current_cred, model, classified_error, increment_consecutive_failures=False, ) if attempt >= self.max_retries - 1: error_accumulator.record_error( current_cred, classified_error, error_message ) lib_logger.warning( f"Cred {mask_credential(current_cred)} failed after max retries. Rotating." ) break wait_time = classified_error.retry_after or ( 2**attempt ) + random.uniform(0, 1) remaining_budget = deadline - time.time() if wait_time > remaining_budget: error_accumulator.record_error( current_cred, classified_error, error_message ) lib_logger.warning( f"Retry wait ({wait_time:.2f}s) exceeds budget. Rotating." ) break lib_logger.warning( f"Cred {mask_credential(current_cred)} server error. Retrying in {wait_time:.2f}s." ) await asyncio.sleep(wait_time) continue except Exception as e: last_exception = e log_failure( api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {}, ) classified_error = classify_error(e, provider=provider) error_message = str(e).split("\n")[0] # Record in accumulator error_accumulator.record_error( current_cred, classified_error, error_message ) lib_logger.warning( f"Cred {mask_credential(current_cred)} {classified_error.error_type} (HTTP {classified_error.status_code})." ) # Check if this error should trigger rotation if not should_rotate_on_error(classified_error): lib_logger.error( f"Non-recoverable error ({classified_error.error_type}). Failing." 
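                                # Propagates past the rotation loops; the outermost
                                # except converts it into an SSE error payload for
                                # the client.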
) raise last_exception await self.usage_manager.record_failure( current_cred, model, classified_error ) break # If the inner loop breaks, it means the key failed and we need to rotate. # Continue to the next iteration of the outer while loop to pick a new key. continue else: # This is the standard API Key / litellm-handled provider logic is_oauth = provider in self.oauth_providers if is_oauth: # Standard OAuth provider (not custom) # ... (logic to set headers) ... pass else: # API Key litellm_kwargs["api_key"] = current_cred provider_instance = self._get_provider_instance(provider) if provider_instance: # Ensure default Gemini safety settings are present (without overriding request) try: self._apply_default_safety_settings( litellm_kwargs, provider ) except Exception: lib_logger.debug( "Could not apply default safety settings for streaming path; continuing." ) if "safety_settings" in litellm_kwargs: converted_settings = ( provider_instance.convert_safety_settings( litellm_kwargs["safety_settings"] ) ) if converted_settings is not None: litellm_kwargs["safety_settings"] = converted_settings else: del litellm_kwargs["safety_settings"] if provider == "gemini" and provider_instance: provider_instance.handle_thinking_parameter( litellm_kwargs, model ) if provider == "nvidia_nim" and provider_instance: provider_instance.handle_thinking_parameter( litellm_kwargs, model ) if "gemma-3" in model and "messages" in litellm_kwargs: litellm_kwargs["messages"] = [ {"role": "user", "content": m["content"]} if m.get("role") == "system" else m for m in litellm_kwargs["messages"] ] litellm_kwargs = sanitize_request_payload(litellm_kwargs, model) # If the provider is 'qwen_code', set the custom provider to 'qwen' # and strip the prefix from the model name for LiteLLM. if provider == "qwen_code": litellm_kwargs["custom_llm_provider"] = "qwen" litellm_kwargs["model"] = model.split("/", 1)[1] for attempt in range(self.max_retries): try: lib_logger.info( f"Attempting stream with credential {mask_credential(current_cred)} (Attempt {attempt + 1}/{self.max_retries})" ) if pre_request_callback: try: await pre_request_callback(request, litellm_kwargs) except Exception as e: if self.abort_on_callback_error: raise PreRequestCallbackError( f"Pre-request callback failed: {e}" ) from e else: lib_logger.warning( f"Pre-request callback failed but abort_on_callback_error is False. Proceeding with request. Error: {e}" ) # lib_logger.info(f"DEBUG: litellm.acompletion kwargs: {litellm_kwargs}") # Convert model parameters for custom providers right before LiteLLM call final_kwargs = self._convert_model_params_for_litellm( **litellm_kwargs ) response = await litellm.acompletion( **final_kwargs, logger_fn=self._litellm_logger_callback, ) lib_logger.info( f"Stream connection established for credential {mask_credential(current_cred)}. Processing response." ) key_acquired = False stream_generator = self._safe_streaming_wrapper( response, current_cred, model, request, provider_instance, ) async for chunk in stream_generator: yield chunk return except ( StreamedAPIError, litellm.RateLimitError, httpx.HTTPStatusError, ) as e: last_exception = e # This is the final, robust handler for streamed errors. error_payload = {} cleaned_str = None # The actual exception might be wrapped in our StreamedAPIError. 
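                                # getattr() falls back to the exception itself when
                                # there is no wrapped payload, so plain litellm and
                                # httpx errors take the same classification path.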
original_exc = getattr(e, "data", e) classified_error = classify_error( original_exc, provider=provider ) # Check if this error should trigger rotation if not should_rotate_on_error(classified_error): lib_logger.error( f"Non-recoverable error ({classified_error.error_type}) during litellm stream. Failing." ) raise last_exception try: # The full error JSON is in the string representation of the exception. json_str_match = re.search( r"(\{.*\})", str(original_exc), re.DOTALL ) if json_str_match: cleaned_str = codecs.decode( json_str_match.group(1), "unicode_escape" ) error_payload = json.loads(cleaned_str) except (json.JSONDecodeError, TypeError): error_payload = {} log_failure( api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {}, raw_response_text=cleaned_str, ) error_details = error_payload.get("error", {}) error_status = error_details.get("status", "") error_message_text = error_details.get( "message", str(original_exc).split("\n")[0] ) # Record in accumulator for client reporting error_accumulator.record_error( current_cred, classified_error, error_message_text ) if ( "quota" in error_message_text.lower() or "resource_exhausted" in error_status.lower() ): consecutive_quota_failures += 1 quota_value = "N/A" quota_id = "N/A" if "details" in error_details and isinstance( error_details.get("details"), list ): for detail in error_details["details"]: if isinstance(detail.get("violations"), list): for violation in detail["violations"]: if "quotaValue" in violation: quota_value = violation[ "quotaValue" ] if "quotaId" in violation: quota_id = violation["quotaId"] if ( quota_value != "N/A" and quota_id != "N/A" ): break await self.usage_manager.record_failure( current_cred, model, classified_error ) if consecutive_quota_failures >= 3: # Fatal: likely input data too large client_error_message = ( f"Request failed after 3 consecutive quota errors (input may be too large). " f"Limit: {quota_value} (Quota ID: {quota_id})" ) lib_logger.error( f"Fatal quota error for {mask_credential(current_cred)}. ID: {quota_id}, Limit: {quota_value}" ) yield f"data: {json.dumps({'error': {'message': client_error_message, 'type': 'proxy_fatal_quota_error'}})}\n\n" yield "data: [DONE]\n\n" return else: lib_logger.warning( f"Cred {mask_credential(current_cred)} quota error ({consecutive_quota_failures}/3). Rotating." ) break else: consecutive_quota_failures = 0 lib_logger.warning( f"Cred {mask_credential(current_cred)} {classified_error.error_type}. Rotating." 
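                                # Rate-limit errors additionally arm a provider-wide
                                # cooldown below, pausing sibling requests instead of
                                # burning through the remaining credentials.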
) if classified_error.error_type == "rate_limit": cooldown_duration = ( classified_error.retry_after or 60 ) await self.cooldown_manager.start_cooldown( provider, cooldown_duration ) await self.usage_manager.record_failure( current_cred, model, classified_error ) break except ( APIConnectionError, litellm.InternalServerError, litellm.ServiceUnavailableError, ) as e: consecutive_quota_failures = 0 last_exception = e log_failure( api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {}, ) classified_error = classify_error(e, provider=provider) error_message_text = str(e).split("\n")[0] # Record error in accumulator (server errors are transient, not abnormal) error_accumulator.record_error( current_cred, classified_error, error_message_text ) # Provider-level error: don't increment consecutive failures await self.usage_manager.record_failure( current_cred, model, classified_error, increment_consecutive_failures=False, ) if attempt >= self.max_retries - 1: lib_logger.warning( f"Credential {mask_credential(current_cred)} failed after max retries for model {model} due to a server error. Rotating key silently." ) # [MODIFIED] Do not yield to the client here. break wait_time = classified_error.retry_after or ( 2**attempt ) + random.uniform(0, 1) remaining_budget = deadline - time.time() if wait_time > remaining_budget: lib_logger.warning( f"Required retry wait time ({wait_time:.2f}s) exceeds remaining budget ({remaining_budget:.2f}s). Rotating key early." ) break lib_logger.warning( f"Credential {mask_credential(current_cred)} encountered a server error for model {model}. Reason: '{error_message_text}'. Retrying in {wait_time:.2f}s." ) await asyncio.sleep(wait_time) continue except Exception as e: consecutive_quota_failures = 0 last_exception = e log_failure( api_key=current_cred, model=model, attempt=attempt + 1, error=e, request_headers=dict(request.headers) if request else {}, ) classified_error = classify_error(e, provider=provider) error_message_text = str(e).split("\n")[0] # Record error in accumulator error_accumulator.record_error( current_cred, classified_error, error_message_text ) lib_logger.warning( f"Credential {mask_credential(current_cred)} failed with {classified_error.error_type} (Status: {classified_error.status_code}). Error: {error_message_text}." ) # Handle rate limits with cooldown (exclude quota_exceeded) if ( classified_error.status_code == 429 and classified_error.error_type != "quota_exceeded" ) or classified_error.error_type == "rate_limit": cooldown_duration = classified_error.retry_after or 60 await self.cooldown_manager.start_cooldown( provider, cooldown_duration ) lib_logger.warning( f"Rate limit detected for {provider}. Starting {cooldown_duration}s cooldown." ) # Check if this error should trigger rotation if not should_rotate_on_error(classified_error): # Non-rotatable errors - fail immediately lib_logger.error( f"Non-recoverable error ({classified_error.error_type}). Failing request." ) raise last_exception # Record failure and rotate to next key await self.usage_manager.record_failure( current_cred, model, classified_error ) lib_logger.info( f"Rotating to next key after {classified_error.error_type} error." 
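                                # break exits the retry loop; the finally clause
                                # releases the credential if it is still held, and
                                # the outer while loop rotates to the next one.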
) break finally: if key_acquired and current_cred: await self.usage_manager.release_key(current_cred, model) # Build detailed error response using error accumulator error_accumulator.timeout_occurred = time.time() >= deadline if error_accumulator.has_errors(): # Log concise summary for server logs lib_logger.error(error_accumulator.build_log_message()) # Build structured error response for client error_response = error_accumulator.build_client_error_response() error_data = error_response else: # Fallback if no errors were recorded (shouldn't happen) final_error_message = ( "Request failed: No available API keys after rotation or timeout." ) if last_exception: final_error_message = ( f"Request failed. Last error: {str(last_exception)}" ) error_data = { "error": {"message": final_error_message, "type": "proxy_error"} } lib_logger.error(final_error_message) yield f"data: {json.dumps(error_data)}\n\n" yield "data: [DONE]\n\n" except NoAvailableKeysError as e: lib_logger.error( f"A streaming request failed because no keys were available within the time budget: {e}" ) error_data = {"error": {"message": str(e), "type": "proxy_busy"}} yield f"data: {json.dumps(error_data)}\n\n" yield "data: [DONE]\n\n" except Exception as e: # This will now only catch fatal errors that should be raised, like invalid requests. lib_logger.error( f"An unhandled exception occurred in streaming retry logic: {e}", exc_info=True, ) error_data = { "error": { "message": f"An unexpected error occurred: {str(e)}", "type": "proxy_internal_error", } } yield f"data: {json.dumps(error_data)}\n\n" yield "data: [DONE]\n\n" def acompletion( self, request: Optional[Any] = None, pre_request_callback: Optional[callable] = None, **kwargs, ) -> Union[Any, AsyncGenerator[str, None]]: """ Dispatcher for completion requests. Args: request: Optional request object, used for client disconnect checks and logging. pre_request_callback: Optional async callback function to be called before each API request attempt. The callback will receive the `request` object and the prepared request `kwargs` as arguments. This can be used for custom logic such as request validation, logging, or rate limiting. If the callback raises an exception, the completion request will be aborted and the exception will propagate. Returns: The completion response object, or an async generator for streaming responses, or None if all retries fail. """ # Handle iflow provider: remove stream_options to avoid HTTP 406 model = kwargs.get("model", "") provider = model.split("/")[0] if "/" in model else "" if provider == "iflow" and "stream_options" in kwargs: lib_logger.debug( "Removing stream_options for iflow provider to avoid HTTP 406" ) kwargs.pop("stream_options", None) if kwargs.get("stream"): # Only add stream_options for providers that support it (excluding iflow) if provider != "iflow": if "stream_options" not in kwargs: kwargs["stream_options"] = {} if "include_usage" not in kwargs["stream_options"]: kwargs["stream_options"]["include_usage"] = True return self._streaming_acompletion_with_retry( request=request, pre_request_callback=pre_request_callback, **kwargs ) else: return self._execute_with_retry( litellm.acompletion, request=request, pre_request_callback=pre_request_callback, **kwargs, ) def aembedding( self, request: Optional[Any] = None, pre_request_callback: Optional[callable] = None, **kwargs, ) -> Any: """ Executes an embedding request with retry logic. Args: request: Optional request object, used for client disconnect checks and logging. 
pre_request_callback: Optional async callback function to be called before each API request attempt. The callback will receive the `request` object and the prepared request `kwargs` as arguments. This can be used for custom logic such as request validation, logging, or rate limiting. If the callback raises an exception, the embedding request will be aborted and the exception will propagate. Returns: The embedding response object, or None if all retries fail. """ return self._execute_with_retry( litellm.aembedding, request=request, pre_request_callback=pre_request_callback, **kwargs, ) def token_count(self, **kwargs) -> int: """Calculates the number of tokens for a given text or list of messages.""" kwargs = self._convert_model_params(**kwargs) model = kwargs.get("model") text = kwargs.get("text") messages = kwargs.get("messages") if not model: raise ValueError("'model' is a required parameter.") if messages: return token_counter(model=model, messages=messages) elif text: return token_counter(model=model, text=text) else: raise ValueError("Either 'text' or 'messages' must be provided.") async def get_available_models(self, provider: str) -> List[str]: """Returns a list of available models for a specific provider, with caching.""" lib_logger.info(f"Getting available models for provider: {provider}") if provider in self._model_list_cache: lib_logger.debug(f"Returning cached models for provider: {provider}") return self._model_list_cache[provider] credentials_for_provider = self.all_credentials.get(provider) if not credentials_for_provider: lib_logger.warning(f"No credentials for provider: {provider}") return [] # Create a copy and shuffle it to randomize the starting credential shuffled_credentials = list(credentials_for_provider) random.shuffle(shuffled_credentials) provider_instance = self._get_provider_instance(provider) if provider_instance: # For providers with hardcoded models (like gemini_cli), we only need to call once. # For others, we might need to try multiple keys if one is invalid. # The current logic of iterating works for both, as the credential is not # always used in get_models. for credential in shuffled_credentials: try: # Display last 6 chars for API keys, or the filename for OAuth paths cred_display = mask_credential(credential) lib_logger.debug( f"Attempting to get models for {provider} with credential {cred_display}" ) models = await provider_instance.get_models( credential, self.http_client ) lib_logger.info( f"Got {len(models)} models for provider: {provider}" ) # Whitelist and blacklist logic final_models = [] for m in models: is_whitelisted = self._is_model_whitelisted(provider, m) is_blacklisted = self._is_model_ignored(provider, m) if is_whitelisted: final_models.append(m) continue if not is_blacklisted: final_models.append(m) if len(final_models) != len(models): lib_logger.info( f"Filtered out {len(models) - len(final_models)} models for provider {provider}." ) self._model_list_cache[provider] = final_models return final_models except Exception as e: classified_error = classify_error(e, provider=provider) cred_display = mask_credential(credential) lib_logger.debug( f"Failed to get models for provider {provider} with credential {cred_display}: {classified_error.error_type}. Trying next credential." ) continue # Try the next credential lib_logger.error( f"Failed to get models for provider {provider} after trying all credentials." 
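        # Failures are not cached here, so the next call to
        # get_available_models() retries every credential from scratch.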
) return [] async def get_all_available_models( self, grouped: bool = True ) -> Union[Dict[str, List[str]], List[str]]: """Returns a list of all available models, either grouped by provider or as a flat list.""" lib_logger.info("Getting all available models...") all_providers = list(self.all_credentials.keys()) tasks = [self.get_available_models(provider) for provider in all_providers] results = await asyncio.gather(*tasks, return_exceptions=True) all_provider_models = {} for provider, result in zip(all_providers, results): if isinstance(result, Exception): lib_logger.error( f"Failed to get models for provider {provider}: {result}" ) all_provider_models[provider] = [] else: all_provider_models[provider] = result lib_logger.info("Finished getting all available models.") if grouped: return all_provider_models else: flat_models = [] for models in all_provider_models.values(): flat_models.extend(models) return flat_models async def get_quota_stats( self, provider_filter: Optional[str] = None, ) -> Dict[str, Any]: """ Get quota and usage stats for all credentials. This returns cached/disk data aggregated by provider. For provider-specific quota info (e.g., Antigravity quota groups), it enriches the data from provider plugins. Args: provider_filter: If provided, only return stats for this provider Returns: Complete stats dict ready for the /v1/quota-stats endpoint """ # Get base stats from usage manager stats = await self.usage_manager.get_stats_for_endpoint(provider_filter) # Enrich with provider-specific quota data for provider, prov_stats in stats.get("providers", {}).items(): provider_class = self._provider_plugins.get(provider) if not provider_class: continue # Get or create provider instance if provider not in self._provider_instances: self._provider_instances[provider] = provider_class() provider_instance = self._provider_instances[provider] # Check if provider has quota tracking (like Antigravity) if hasattr(provider_instance, "_get_effective_quota_groups"): # Add quota group summary quota_groups = provider_instance._get_effective_quota_groups() prov_stats["quota_groups"] = {} for group_name, group_models in quota_groups.items(): group_stats = { "models": group_models, "credentials_total": 0, "credentials_exhausted": 0, "avg_remaining_pct": 0, "total_remaining_pcts": [], # Total requests tracking across all credentials "total_requests_used": 0, "total_requests_max": 0, # Tier breakdown: tier_name -> {"total": N, "active": M} "tiers": {}, } # Calculate per-credential quota for this group for cred in prov_stats.get("credentials", []): models_data = cred.get("models", {}) group_stats["credentials_total"] += 1 # Track tier - get directly from provider cache since cred["tier"] not set yet tier = cred.get("tier") if not tier and hasattr( provider_instance, "project_tier_cache" ): cred_path = cred.get("full_path", "") tier = provider_instance.project_tier_cache.get(cred_path) tier = tier or "unknown" # Initialize tier entry if needed with priority for sorting if tier not in group_stats["tiers"]: priority = 10 # default if hasattr(provider_instance, "_resolve_tier_priority"): priority = provider_instance._resolve_tier_priority( tier ) group_stats["tiers"][tier] = { "total": 0, "active": 0, "priority": priority, } group_stats["tiers"][tier]["total"] += 1 # Find model with VALID baseline (not just any model with stats) model_stats = None for model in group_models: candidate = self._find_model_stats_in_data( models_data, model, provider, provider_instance ) if candidate: baseline = 
candidate.get("baseline_remaining_fraction") if baseline is not None: model_stats = candidate break # Keep first found as fallback (for request counts) if model_stats is None: model_stats = candidate if model_stats: baseline = model_stats.get("baseline_remaining_fraction") req_count = model_stats.get("request_count", 0) max_req = model_stats.get("quota_max_requests") or 0 # Accumulate totals (one model per group per credential) group_stats["total_requests_used"] += req_count group_stats["total_requests_max"] += max_req if baseline is not None: remaining_pct = int(baseline * 100) group_stats["total_remaining_pcts"].append( remaining_pct ) if baseline <= 0: group_stats["credentials_exhausted"] += 1 else: # Credential is active (has quota remaining) group_stats["tiers"][tier]["active"] += 1 # Calculate average remaining percentage (per-credential average) if group_stats["total_remaining_pcts"]: group_stats["avg_remaining_pct"] = int( sum(group_stats["total_remaining_pcts"]) / len(group_stats["total_remaining_pcts"]) ) del group_stats["total_remaining_pcts"] # Calculate total remaining percentage (global) if group_stats["total_requests_max"] > 0: used = group_stats["total_requests_used"] max_r = group_stats["total_requests_max"] group_stats["total_remaining_pct"] = max( 0, int((1 - used / max_r) * 100) ) else: group_stats["total_remaining_pct"] = None prov_stats["quota_groups"][group_name] = group_stats # Also enrich each credential with formatted quota group info for cred in prov_stats.get("credentials", []): cred["model_groups"] = {} models_data = cred.get("models", {}) for group_name, group_models in quota_groups.items(): # Find model with VALID baseline (prefer over any model with stats) # Also track the best reset_ts across all models in the group model_stats = None best_reset_ts = None for model in group_models: candidate = self._find_model_stats_in_data( models_data, model, provider, provider_instance ) if candidate: # Track the best (latest) reset_ts from any model in group candidate_reset_ts = candidate.get("quota_reset_ts") if candidate_reset_ts: if ( best_reset_ts is None or candidate_reset_ts > best_reset_ts ): best_reset_ts = candidate_reset_ts baseline = candidate.get("baseline_remaining_fraction") if baseline is not None: model_stats = candidate # Don't break - continue to find best reset_ts # Keep first found as fallback if model_stats is None: model_stats = candidate if model_stats: baseline = model_stats.get("baseline_remaining_fraction") max_req = model_stats.get("quota_max_requests") req_count = model_stats.get("request_count", 0) # Use best_reset_ts from any model in the group reset_ts = best_reset_ts or model_stats.get( "quota_reset_ts" ) remaining_pct = ( int(baseline * 100) if baseline is not None else None ) is_exhausted = baseline is not None and baseline <= 0 # Format reset time reset_iso = None if reset_ts: try: from datetime import datetime, timezone reset_iso = datetime.fromtimestamp( reset_ts, tz=timezone.utc ).isoformat() except (ValueError, OSError): pass cred["model_groups"][group_name] = { "remaining_pct": remaining_pct, "requests_used": req_count, "requests_max": max_req, "display": f"{req_count}/{max_req}" if max_req else f"{req_count}/?", "is_exhausted": is_exhausted, "reset_time_iso": reset_iso, "models": group_models, "confidence": self._get_baseline_confidence( model_stats ), } # Recalculate credential's requests from model_groups # This fixes double-counting when models share quota groups if cred.get("model_groups"): group_requests = sum( 
g.get("requests_used", 0) for g in cred["model_groups"].values() ) cred["requests"] = group_requests # HACK: Fix global requests if present # This is a simplified fix that sets global.requests = current group_requests. # TODO: Properly track archived requests per quota group in usage_manager.py # so that global stats correctly sum: current_period + archived_periods # without double-counting models that share quota groups. # See: usage_manager.py lines 2388-2404 where global stats are built # by iterating all models (causing double-counting for grouped models). if cred.get("global"): cred["global"]["requests"] = group_requests # Try to get email from provider's cache cred_path = cred.get("full_path", "") if hasattr(provider_instance, "project_tier_cache"): tier = provider_instance.project_tier_cache.get(cred_path) if tier: cred["tier"] = tier return stats def _find_model_stats_in_data( self, models_data: Dict[str, Any], model: str, provider: str, provider_instance: Any, ) -> Optional[Dict[str, Any]]: """ Find model stats in models_data, trying various name variants. Handles aliased model names (e.g., gemini-3-pro-preview -> gemini-3-pro-high) by using the provider's _user_to_api_model() mapping. Args: models_data: Dict of model_name -> stats from credential model: Model name to look up (user-facing name) provider: Provider name for prefixing provider_instance: Provider instance for alias methods Returns: Model stats dict if found, None otherwise """ # Try direct match with and without provider prefix prefixed_model = f"{provider}/{model}" model_stats = models_data.get(prefixed_model) or models_data.get(model) if model_stats: return model_stats # Try with API model name (e.g., gemini-3-pro-preview -> gemini-3-pro-high) if hasattr(provider_instance, "_user_to_api_model"): api_model = provider_instance._user_to_api_model(model) if api_model != model: prefixed_api = f"{provider}/{api_model}" model_stats = models_data.get(prefixed_api) or models_data.get( api_model ) return model_stats def _get_baseline_confidence(self, model_stats: Dict) -> str: """ Determine confidence level based on baseline age. Args: model_stats: Model statistics dict with baseline_fetched_at Returns: "high" | "medium" | "low" """ baseline_fetched_at = model_stats.get("baseline_fetched_at") if not baseline_fetched_at: return "low" age_seconds = time.time() - baseline_fetched_at if age_seconds < 300: # 5 minutes return "high" elif age_seconds < 1800: # 30 minutes return "medium" return "low" async def reload_usage_from_disk(self) -> None: """ Force reload usage data from disk. Useful when wanting fresh stats without making external API calls. """ await self.usage_manager.reload_from_disk() async def force_refresh_quota( self, provider: Optional[str] = None, credential: Optional[str] = None, ) -> Dict[str, Any]: """ Force refresh quota from external API. For Antigravity, this fetches live quota data from the API. For other providers, this is a no-op (just reloads from disk). 
Args: provider: If specified, only refresh this provider credential: If specified, only refresh this specific credential Returns: Refresh result dict with success/failure info """ result = { "action": "force_refresh", "scope": "credential" if credential else ("provider" if provider else "all"), "provider": provider, "credential": credential, "credentials_refreshed": 0, "success_count": 0, "failed_count": 0, "duration_ms": 0, "errors": [], } start_time = time.time() # Determine which providers to refresh if provider: providers_to_refresh = ( [provider] if provider in self.all_credentials else [] ) else: providers_to_refresh = list(self.all_credentials.keys()) for prov in providers_to_refresh: provider_class = self._provider_plugins.get(prov) if not provider_class: continue # Get or create provider instance if prov not in self._provider_instances: self._provider_instances[prov] = provider_class() provider_instance = self._provider_instances[prov] # Check if provider supports quota refresh (like Antigravity) if hasattr(provider_instance, "fetch_initial_baselines"): # Get credentials to refresh if credential: # Find full path for this credential creds_to_refresh = [] for cred_path in self.all_credentials.get(prov, []): if cred_path.endswith(credential) or cred_path == credential: creds_to_refresh.append(cred_path) break else: creds_to_refresh = self.all_credentials.get(prov, []) if not creds_to_refresh: continue try: # Fetch live quota from API for ALL specified credentials quota_results = await provider_instance.fetch_initial_baselines( creds_to_refresh ) # Store baselines in usage manager if hasattr(provider_instance, "_store_baselines_to_usage_manager"): stored = ( await provider_instance._store_baselines_to_usage_manager( quota_results, self.usage_manager ) ) result["success_count"] += stored result["credentials_refreshed"] += len(creds_to_refresh) # Count failures for cred_path, data in quota_results.items(): if data.get("status") != "success": result["failed_count"] += 1 result["errors"].append( f"{Path(cred_path).name}: {data.get('error', 'Unknown error')}" ) except Exception as e: lib_logger.error(f"Failed to refresh quota for {prov}: {e}") result["errors"].append(f"{prov}: {str(e)}") result["failed_count"] += len(creds_to_refresh) result["duration_ms"] = int((time.time() - start_time) * 1000) return result
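

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library API). The provider
# name, model id, and keys below are hypothetical placeholders, and a real
# application would also handle shutdown/cleanup of the client. It shows the
# streaming call shape and the pre-request callback contract documented in
# acompletion().
if __name__ == "__main__":

    async def _log_attempt(request: Any, prepared_kwargs: Dict[str, Any]) -> None:
        # Called before every attempt; raising here aborts the request when
        # abort_on_callback_error is True.
        print(f"About to call model {prepared_kwargs.get('model')}")

    async def _demo() -> None:
        client = RotatingClient(
            api_keys={"gemini": ["AIza-example-1", "AIza-example-2"]},  # hypothetical
            max_retries=2,
        )
        stream = client.acompletion(
            model="gemini/gemini-2.0-flash",  # hypothetical model id
            messages=[{"role": "user", "content": "Say hello."}],
            stream=True,
            pre_request_callback=_log_attempt,
        )
        # Streamed responses arrive as SSE-formatted strings and always end
        # with a "data: [DONE]" line, even when every credential fails.
        async for sse_line in stream:
            print(sse_line, end="")

    asyncio.run(_demo())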