Spaces:
Paused
Paused
| """ | |
| Common helpers / utils across all OpenAI endpoints | |
| """ | |
| import hashlib | |
| import json | |
| from typing import Any, Dict, List, Literal, Optional, Union | |
| import httpx | |
| import openai | |
| from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI | |
| import litellm | |
| from litellm.llms.base_llm.chat.transformation import BaseLLMException | |
| from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS | |
class OpenAIError(BaseLLMException):
    """Exception type for errors surfaced while talking to the OpenAI API.

    The instance always carries a ``request`` and a ``response`` attribute:
    when the caller does not provide them, placeholders are built so that
    the values forwarded to ``BaseLLMException`` are never ``None``.

    Args:
        status_code: HTTP status code associated with the failure.
        message: Human-readable error description.
        request: The originating request; a placeholder ``POST`` request to
            ``https://api.openai.com/v1`` is created when omitted.
        response: The response that triggered the error; a placeholder
            response carrying ``status_code`` is created when omitted.
        headers: Response headers, if any.
        body: Parsed error body, if any; forwarded to the base class.
    """

    def __init__(
        self,
        status_code: int,
        message: str,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
        headers: Optional[Union[dict, httpx.Headers]] = None,
        body: Optional[dict] = None,
    ):
        self.status_code = status_code
        self.message = message
        self.headers = headers

        # Fall back to a synthetic request when none was supplied.
        self.request = request or httpx.Request(
            method="POST", url="https://api.openai.com/v1"
        )
        # The synthetic response is tied to whichever request was chosen above.
        self.response = response or httpx.Response(
            status_code=status_code, request=self.request
        )

        super().__init__(
            status_code=status_code,
            message=message,
            headers=headers,
            request=self.request,
            response=self.response,
            body=body,
        )
| ####### Error Handling Utils for OpenAI API ####################### | |
| ################################################################### | |
def drop_params_from_unprocessable_entity_error(
    e: Union[openai.UnprocessableEntityError, httpx.HTTPStatusError],
    data: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Read an unprocessable-entity error and drop the params it complains about.

    The error body is expected to carry a ``message`` that is either a dict or
    a JSON-encoded string containing a ``detail`` list. Each entry of that list
    whose ``loc`` is a two-element list contributes ``loc[1]`` as the name of
    an invalid parameter. Any payload that does not match this shape is
    treated as "nothing to drop" instead of raising.

    Args:
        e (UnprocessableEntityError | httpx.HTTPStatusError): The exception to inspect.
            For ``httpx.HTTPStatusError`` the body is taken from the ``error``
            field of the JSON response; otherwise ``e.body`` is used.
        data (Dict[str, Any]): The original data dictionary containing all parameters

    Returns:
        Dict[str, Any]: A new dictionary with invalid parameters removed.
        ``data`` itself is never mutated.
    """
    if isinstance(e, httpx.HTTPStatusError):
        try:
            error_json = e.response.json()
        except ValueError:
            # Body was not valid JSON (json.JSONDecodeError is a ValueError):
            # there is nothing we can learn about invalid params from it.
            error_json = None
        error_body = error_json.get("error") if isinstance(error_json, dict) else None
    else:
        error_body = e.body

    invalid_params: List[str] = []
    if isinstance(error_body, dict) and error_body.get("message"):
        message = error_body["message"]
        if isinstance(message, str):
            try:
                message = json.loads(message)
            except json.JSONDecodeError:
                message = {"detail": message}

        # json.loads may legitimately yield a list / number / string, and the
        # raw message may be a non-dict object; only dicts can carry "detail".
        detail = message.get("detail") if isinstance(message, dict) else None

        if isinstance(detail, list) and detail and isinstance(detail[0], dict):
            for error_dict in detail:
                if not isinstance(error_dict, dict):
                    # Only the first entry was validated above; skip stray
                    # non-dict entries rather than failing on `.get`.
                    continue
                loc = error_dict.get("loc")
                if isinstance(loc, list) and len(loc) == 2:
                    invalid_params.append(loc[1])

    return {k: v for k, v in data.items() if k not in invalid_params}
class BaseOpenAILLM:
    """
    Base class for OpenAI LLMs for getting their httpx clients and SSL verification settings

    All helpers are static: they take no instance state and can be called
    either on the class or on an instance of a subclass.
    """

    @staticmethod
    def get_cached_openai_client(
        client_initialization_params: dict, client_type: Literal["openai", "azure"]
    ) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]]:
        """Retrieves the OpenAI client from the in-memory cache based on the client initialization parameters

        Returns whatever ``litellm.in_memory_llm_clients_cache.get_cache``
        yields for the computed cache key.
        """
        _cache_key = BaseOpenAILLM.get_openai_client_cache_key(
            client_initialization_params=client_initialization_params,
            client_type=client_type,
        )
        return litellm.in_memory_llm_clients_cache.get_cache(_cache_key)

    @staticmethod
    def set_cached_openai_client(
        openai_client: Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI],
        client_type: Literal["openai", "azure"],
        client_initialization_params: dict,
    ):
        """Stores the OpenAI client in the in-memory cache for _DEFAULT_TTL_FOR_HTTPX_CLIENTS SECONDS"""
        _cache_key = BaseOpenAILLM.get_openai_client_cache_key(
            client_initialization_params=client_initialization_params,
            client_type=client_type,
        )
        litellm.in_memory_llm_clients_cache.set_cache(
            key=_cache_key,
            value=openai_client,
            ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
        )

    @staticmethod
    def get_openai_client_cache_key(
        client_initialization_params: dict, client_type: Literal["openai", "azure"]
    ) -> str:
        """Creates a cache key for the OpenAI client based on the client initialization parameters

        The key is a comma-joined list of ``name=value`` pairs. The API key is
        represented only by its SHA-256 hex digest (``hashed_api_key=...``);
        the raw ``api_key`` value is never written into the key.
        """
        api_key = client_initialization_params.get("api_key")
        hashed_api_key = None
        if api_key is not None:
            # Hexadecimal representation of the hash
            hashed_api_key = hashlib.sha256(api_key.encode()).hexdigest()

        # Create a more readable cache key using a list of key-value pairs
        key_parts = [
            f"hashed_api_key={hashed_api_key}",
            f"is_async={client_initialization_params.get('is_async')}",
        ]

        LITELLM_CLIENT_SPECIFIC_PARAMS = [
            "timeout",
            "max_retries",
            "organization",
            "api_base",
        ]
        openai_client_fields = (
            BaseOpenAILLM.get_openai_client_initialization_param_fields(
                client_type=client_type
            )
            + LITELLM_CLIENT_SPECIFIC_PARAMS
        )

        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # key stays deterministic when a litellm-specific param is also a
        # client constructor param.
        for param in dict.fromkeys(openai_client_fields):
            if param == "api_key":
                # Already covered by hashed_api_key above; emitting the raw
                # value here would leak the secret into the cache key.
                continue
            key_parts.append(f"{param}={client_initialization_params.get(param)}")

        return ",".join(key_parts)

    @staticmethod
    def get_openai_client_initialization_param_fields(
        client_type: Literal["openai", "azure"]
    ) -> List[str]:
        """Returns a list of fields that are used to initialize the OpenAI client

        The names are read from the ``__init__`` signature of ``OpenAI`` when
        ``client_type`` is ``"openai"`` and of ``AzureOpenAI`` otherwise.
        """
        import inspect

        client_cls = OpenAI if client_type == "openai" else AzureOpenAI
        signature = inspect.signature(client_cls.__init__)
        # Extract parameter names, excluding 'self'
        return [param for param in signature.parameters if param != "self"]

    @staticmethod
    def _get_async_http_client() -> Optional[httpx.AsyncClient]:
        """Return ``litellm.aclient_session`` if set, else a new ``httpx.AsyncClient``.

        A newly created client uses ``litellm.ssl_verify`` for certificate
        verification and the connection limits given below.
        """
        if litellm.aclient_session is not None:
            return litellm.aclient_session

        return httpx.AsyncClient(
            limits=httpx.Limits(max_connections=1000, max_keepalive_connections=100),
            verify=litellm.ssl_verify,
        )

    @staticmethod
    def _get_sync_http_client() -> Optional[httpx.Client]:
        """Return ``litellm.client_session`` if set, else a new ``httpx.Client``.

        A newly created client uses ``litellm.ssl_verify`` for certificate
        verification and the connection limits given below.
        """
        if litellm.client_session is not None:
            return litellm.client_session

        return httpx.Client(
            limits=httpx.Limits(max_connections=1000, max_keepalive_connections=100),
            verify=litellm.ssl_verify,
        )