# What is this?
## Controller file for Predibase Integration - https://predibase.com/
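#
# Minimal usage sketch (illustrative only; the model name, tenant id, and key
# below are assumed placeholders, not values confirmed by this file). This
# handler is normally reached through litellm's provider routing rather than
# being instantiated directly:
#
#     import os
#     import litellm
#
#     response = litellm.completion(
#         model="predibase/llama-3-8b-instruct",   # assumed deployment name
#         messages=[{"role": "user", "content": "Hello"}],
#         api_key=os.environ.get("PREDIBASE_API_KEY"),
#         tenant_id="my-tenant-id",                # assumed tenant id
#     )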
import json
import os
import time
from functools import partial
from typing import Callable, Optional, Union

import httpx  # type: ignore

import litellm
import litellm.litellm_core_utils
import litellm.litellm_core_utils.litellm_logging
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.prompt_templates.factory import (
    custom_prompt,
    prompt_factory,
)
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    get_async_httpx_client,
)
from litellm.types.utils import LiteLLMLoggingBaseClass
from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage

from ..common_utils import PredibaseError

async def make_call(
    client: AsyncHTTPHandler,
    api_base: str,
    headers: dict,
    data: str,
    model: str,
    messages: list,
    logging_obj,
    timeout: Optional[Union[float, httpx.Timeout]],
):
    response = await client.post(
        api_base, headers=headers, data=data, stream=True, timeout=timeout
    )

    if response.status_code != 200:
        raise PredibaseError(status_code=response.status_code, message=response.text)

    completion_stream = response.aiter_lines()
    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response=completion_stream,  # Pass the completion stream for logging
        additional_args={"complete_input_dict": data},
    )

    return completion_stream

class PredibaseChatCompletion:
    def __init__(self) -> None:
        super().__init__()

    def output_parser(self, generated_text: str):
        """
        Parse the output text to remove any special tokens. In the current approach we just check for ChatML tokens.

        Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
        """
        chat_template_tokens = [
            "<|assistant|>",
            "<|system|>",
            "<|user|>",
            "<s>",
            "</s>",
        ]
        for token in chat_template_tokens:
            if generated_text.strip().startswith(token):
                generated_text = generated_text.replace(token, "", 1)
            if generated_text.endswith(token):
                generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
        return generated_text

    def process_response(  # noqa: PLR0915
        self,
        model: str,
        response: httpx.Response,
        model_response: ModelResponse,
        stream: bool,
        logging_obj: LiteLLMLoggingBaseClass,
        optional_params: dict,
        api_key: str,
        data: Union[dict, str],
        messages: list,
        print_verbose,
        encoding,
    ) -> ModelResponse:
        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=response.text,
            additional_args={"complete_input_dict": data},
        )
        print_verbose(f"raw model_response: {response.text}")
        ## RESPONSE OBJECT
        try:
            completion_response = response.json()
        except Exception:
            raise PredibaseError(message=response.text, status_code=422)
        if "error" in completion_response:
            raise PredibaseError(
                message=str(completion_response["error"]),
                status_code=response.status_code,
            )
        else:
            if not isinstance(completion_response, dict):
                raise PredibaseError(
                    status_code=422,
                    message=f"'completion_response' is not a dictionary - {completion_response}",
                )
            elif "generated_text" not in completion_response:
                raise PredibaseError(
                    status_code=422,
                    message=f"'generated_text' is not a key in the response dictionary - {completion_response}",
                )
| if len(completion_response["generated_text"]) > 0: | |
| model_response.choices[0].message.content = self.output_parser( # type: ignore | |
| completion_response["generated_text"] | |
| ) | |
| ## GETTING LOGPROBS + FINISH REASON | |
| if ( | |
| "details" in completion_response | |
| and "tokens" in completion_response["details"] | |
| ): | |
| model_response.choices[0].finish_reason = map_finish_reason( | |
| completion_response["details"]["finish_reason"] | |
| ) | |
| sum_logprob = 0 | |
| for token in completion_response["details"]["tokens"]: | |
| if token["logprob"] is not None: | |
| sum_logprob += token["logprob"] | |
| setattr( | |
| model_response.choices[0].message, # type: ignore | |
| "_logprob", | |
| sum_logprob, # [TODO] move this to using the actual logprobs | |
| ) | |
| if "best_of" in optional_params and optional_params["best_of"] > 1: | |
| if ( | |
| "details" in completion_response | |
| and "best_of_sequences" in completion_response["details"] | |
| ): | |
| choices_list = [] | |
| for idx, item in enumerate( | |
| completion_response["details"]["best_of_sequences"] | |
| ): | |
| sum_logprob = 0 | |
| for token in item["tokens"]: | |
| if token["logprob"] is not None: | |
| sum_logprob += token["logprob"] | |
| if len(item["generated_text"]) > 0: | |
| message_obj = Message( | |
| content=self.output_parser(item["generated_text"]), | |
| logprobs=sum_logprob, | |
| ) | |
| else: | |
| message_obj = Message(content=None) | |
| choice_obj = Choices( | |
| finish_reason=map_finish_reason(item["finish_reason"]), | |
| index=idx + 1, | |
| message=message_obj, | |
| ) | |
| choices_list.append(choice_obj) | |
| model_response.choices.extend(choices_list) | |
| ## CALCULATING USAGE | |
| prompt_tokens = 0 | |
| try: | |
| prompt_tokens = litellm.token_counter(messages=messages) | |
        except Exception:
            # this should remain non-blocking; we should not block a response from returning if usage calculation fails
            pass

        output_text = model_response["choices"][0]["message"].get("content", "")
        if output_text is not None and len(output_text) > 0:
            completion_tokens = 0
            try:
                completion_tokens = len(
                    encoding.encode(
                        model_response["choices"][0]["message"].get("content", "")
                    )
                )  ##[TODO] use a model-specific tokenizer
            except Exception:
                # this should remain non-blocking; we should not block a response from returning if usage calculation fails
                pass
        else:
            completion_tokens = 0
        total_tokens = prompt_tokens + completion_tokens

        model_response.created = int(time.time())
        model_response.model = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
        )
        model_response.usage = usage  # type: ignore

        ## RESPONSE HEADERS
        predibase_headers = response.headers
        response_headers = {}
        for k, v in predibase_headers.items():
            if k.startswith("x-"):
                response_headers["llm_provider-{}".format(k)] = v
        model_response._hidden_params["additional_headers"] = response_headers

        return model_response

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key: str,
        logging_obj,
        optional_params: dict,
        litellm_params: dict,
        tenant_id: str,
        timeout: Union[float, httpx.Timeout],
        acompletion=None,
        logger_fn=None,
        headers: dict = {},
    ) -> Union[ModelResponse, CustomStreamWrapper]:
        headers = litellm.PredibaseConfig().validate_environment(
            api_key=api_key,
            headers=headers,
            messages=messages,
            optional_params=optional_params,
            model=model,
            litellm_params=litellm_params,
        )
        completion_url = ""
        input_text = ""
        base_url = "https://serving.app.predibase.com"

        if "https" in model:
            completion_url = model
        elif api_base:
            base_url = api_base
        elif "PREDIBASE_API_BASE" in os.environ:
            base_url = os.getenv("PREDIBASE_API_BASE", "")

        completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}"

        if optional_params.get("stream", False) is True:
            completion_url += "/generate_stream"
        else:
            completion_url += "/generate"
        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details["roles"],
                initial_prompt_value=model_prompt_details["initial_prompt_value"],
                final_prompt_value=model_prompt_details["final_prompt_value"],
                messages=messages,
            )
        else:
            prompt = prompt_factory(model=model, messages=messages)
        ## Load Config
        config = litellm.PredibaseConfig.get_config()
        for k, v in config.items():
            if (
                k not in optional_params
            ):  # completion(top_k=3) > predibase_config(top_k=3) <- allows for dynamic variables to be passed in
                optional_params[k] = v
        stream = optional_params.pop("stream", False)

        data = {
            "inputs": prompt,
            "parameters": optional_params,
        }
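        # Illustrative request body ("parameters" carries whatever survived in
        # optional_params; max_new_tokens is an assumed example):
        #   {"inputs": "What is the capital of France?", "parameters": {"max_new_tokens": 256}}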
        input_text = prompt
        ## LOGGING
        logging_obj.pre_call(
            input=input_text,
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "headers": headers,
                "api_base": completion_url,
                "acompletion": acompletion,
            },
        )
        ## COMPLETION CALL
        if acompletion is True:
            ### ASYNC STREAMING
            if stream is True:
                return self.async_streaming(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=completion_url,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                )  # type: ignore
            else:
                ### ASYNC COMPLETION
                return self.async_completion(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=completion_url,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=False,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                )  # type: ignore

        ### SYNC STREAMING
        if stream is True:
            response = litellm.module_level_client.post(
                completion_url,
                headers=headers,
                data=json.dumps(data),
                stream=stream,
                timeout=timeout,  # type: ignore
            )
            _response = CustomStreamWrapper(
                response.iter_lines(),
                model,
                custom_llm_provider="predibase",
                logging_obj=logging_obj,
            )
            return _response
        ### SYNC COMPLETION
        else:
            response = litellm.module_level_client.post(
                url=completion_url,
                headers=headers,
                data=json.dumps(data),
                timeout=timeout,  # type: ignore
            )
            return self.process_response(
                model=model,
                response=response,
                model_response=model_response,
                stream=optional_params.get("stream", False),
                logging_obj=logging_obj,  # type: ignore
                optional_params=optional_params,
                api_key=api_key,
                data=data,
                messages=messages,
                print_verbose=print_verbose,
                encoding=encoding,
            )

    async def async_completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        stream,
        data: dict,
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
        litellm_params=None,
        logger_fn=None,
        headers={},
    ) -> ModelResponse:
        async_handler = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.PREDIBASE,
            params={"timeout": timeout},
        )
        try:
            response = await async_handler.post(
                api_base, headers=headers, data=json.dumps(data)
            )
        except httpx.HTTPStatusError as e:
            raise PredibaseError(
                status_code=e.response.status_code,
                message="HTTPStatusError - received status_code={}, error_message={}".format(
                    e.response.status_code, e.response.text
                ),
            )
        except Exception as e:
            for exception in litellm.LITELLM_EXCEPTION_TYPES:
                if isinstance(e, exception):
                    raise e
            raise PredibaseError(
                status_code=500, message="{}".format(str(e))
            )  # don't use verbose_logger.exception, if exception is raised
        return self.process_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=stream,
            logging_obj=logging_obj,
            api_key=api_key,
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            optional_params=optional_params,
            encoding=encoding,
        )

    async def async_streaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        data: dict,
        timeout: Union[float, httpx.Timeout],
        optional_params=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
    ) -> CustomStreamWrapper:
        data["stream"] = True

        streamwrapper = CustomStreamWrapper(
            completion_stream=None,
            make_call=partial(
                make_call,
                api_base=api_base,
                headers=headers,
                data=json.dumps(data),
                model=model,
                messages=messages,
                logging_obj=logging_obj,
                timeout=timeout,
            ),
            model=model,
            custom_llm_provider="predibase",
            logging_obj=logging_obj,
        )
        return streamwrapper

    def embedding(self, *args, **kwargs):
        pass