Spaces:
Paused
Paused
| from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union | |
| from httpx import Headers, Response | |
| from litellm.constants import DEFAULT_MAX_TOKENS | |
| from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException | |
| from litellm.types.llms.openai import AllMessageValues | |
| from litellm.types.utils import ModelResponse | |
| from ..common_utils import PredibaseError | |
if not TYPE_CHECKING:
    # Runtime path: the logging module is not imported here, so annotations
    # that use this alias resolve to plain ``Any``.
    LiteLLMLoggingObj = Any
else:
    # Static-analysis path only: give type checkers the concrete logging type.
    from litellm.litellm_core_utils.litellm_logging import (
        Logging as _LiteLLMLoggingObj,
    )

    LiteLLMLoggingObj = _LiteLLMLoggingObj
class PredibaseConfig(BaseConfig):
    """
    Provider configuration for Predibase text generation.

    Reference: https://docs.predibase.com/user-guide/inference/rest_api

    The class attributes below are the provider-level default generation
    parameters. ``__init__`` writes every non-None argument onto the *class*
    (via ``setattr(self.__class__, ...)``), not onto the instance. Constructing
    a config therefore changes the values seen through the class and through
    every other instance.
    """

    adapter_id: Optional[str] = None
    adapter_source: Optional[Literal["pbase", "hub", "s3"]] = None
    best_of: Optional[int] = None
    decoder_input_details: Optional[bool] = None
    details: bool = True  # enables returning logprobs + best of
    max_new_tokens: int = (
        DEFAULT_MAX_TOKENS  # openai default - requests hang if max_new_tokens not given
    )
    repetition_penalty: Optional[float] = None
    # by default don't return the input as part of the output
    return_full_text: Optional[bool] = False
    seed: Optional[int] = None
    stop: Optional[List[str]] = None
    temperature: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[float] = None  # nucleus-sampling probability, not an int
    truncate: Optional[int] = None
    typical_p: Optional[float] = None
    watermark: Optional[bool] = None

    def __init__(
        self,
        best_of: Optional[int] = None,
        decoder_input_details: Optional[bool] = None,
        details: Optional[bool] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        return_full_text: Optional[bool] = None,
        seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: Optional[bool] = None,
    ) -> None:
        """
        Override class-level defaults with any non-None arguments.

        The values are stored on the class, so they are shared by all
        instances. Arguments left as ``None`` keep the existing class value.
        """
        # Snapshot the arguments first: any local bound before this line would
        # also be picked up by ``locals()`` and written onto the class.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """
        Return the provider config via ``BaseConfig.get_config``.

        Declared as a classmethod so it can be called on the class itself
        (``PredibaseConfig.get_config()``) as well as on an instance. Without
        the decorator, a class-level call fails with ``TypeError``.
        """
        return super().get_config()

    def get_supported_openai_params(self, model: str):
        """Return the OpenAI-style parameter names this provider accepts."""
        return [
            "stream",
            "temperature",
            "max_completion_tokens",
            "max_tokens",
            "top_p",
            "stop",
            "n",
            "response_format",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Translate OpenAI-style parameters into Predibase generation parameters.

        Args:
            non_default_params: OpenAI-style params explicitly set by the caller.
            optional_params: Provider params being built; updated in place.
            model: Model name (not used by this mapping).
            drop_params: Not used here. Keys this method does not recognise
                are skipped regardless.

        Returns:
            The same ``optional_params`` dict, after the updates.
        """
        for param, value in non_default_params.items():
            # temperature, top_p, n, stream, stop, max_tokens, presence_penalty default to None
            if param == "temperature":
                if value == 0:
                    # hugging face exception raised when temp==0
                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
                    value = 0.01
                optional_params["temperature"] = value
            elif param == "top_p":
                optional_params["top_p"] = value
            elif param == "n":
                optional_params["best_of"] = value
                # Need to sample if you want best of for hf inference endpoints
                optional_params["do_sample"] = True
            elif param == "stream":
                optional_params["stream"] = value
            elif param == "stop":
                optional_params["stop"] = value
            elif param == "max_tokens" or param == "max_completion_tokens":
                # HF TGI raises the following exception when max_new_tokens==0
                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
                if value == 0:
                    value = 1
                optional_params["max_new_tokens"] = value
            elif param == "echo":
                # NOTE: handled here although not listed in get_supported_openai_params.
                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
                optional_params["decoder_input_details"] = True
            elif param == "response_format":
                optional_params["response_format"] = value
        return optional_params

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: str,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Not implemented here; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "Predibase transformation currently done in handler.py. Need to migrate to this file."
        )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Not implemented here; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "Predibase transformation currently done in handler.py. Need to migrate to this file."
        )

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        """Wrap a failed response in a ``PredibaseError``."""
        return PredibaseError(
            status_code=status_code, message=error_message, headers=headers
        )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Build the request headers for a Predibase call.

        Returns:
            The default JSON content type and bearer ``Authorization`` header,
            merged with ``headers``. Caller-supplied keys win on conflict.
            When ``headers`` is ``None`` the defaults alone are returned.

        Raises:
            ValueError: if ``api_key`` is ``None``.
        """
        if api_key is None:
            raise ValueError(
                "Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
            )
        default_headers = {
            "content-type": "application/json",
            "Authorization": "Bearer {}".format(api_key),
        }
        if headers is None:
            # Without this branch a None ``headers`` came back as None and the
            # auth header built above was silently discarded.
            return default_headers
        if isinstance(headers, dict):
            return {**default_headers, **headers}
        # Non-dict header objects are passed through unchanged, as before.
        return headers