| """ | |
| Translate from OpenAI's `/v1/chat/completions` to Sagemaker's `/invoke` | |
| In the Huggingface TGI format. | |
| """ | |

import json
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from httpx._models import Headers, Response

import litellm
from litellm.litellm_core_utils.asyncify import asyncify
from litellm.litellm_core_utils.prompt_templates.factory import (
    custom_prompt,
    prompt_factory,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage
from litellm.utils import token_counter

from ..common_utils import SagemakerError

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class SagemakerConfig(BaseConfig):
    """
    Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb
    """

    max_new_tokens: Optional[int] = None
    max_completion_tokens: Optional[int] = None
    top_p: Optional[float] = None
    temperature: Optional[float] = None
    return_full_text: Optional[bool] = None

    def __init__(
        self,
        max_new_tokens: Optional[int] = None,
        max_completion_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None,
        return_full_text: Optional[bool] = None,
    ) -> None:
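        # Any explicitly passed value is stored on the class itself (not the
        # instance), so it becomes the default returned by later get_config() calls.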
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        return SagemakerError(
            message=error_message, status_code=status_code, headers=headers
        )

    def get_supported_openai_params(self, model: str) -> List:
        return [
            "stream",
            "temperature",
            "max_tokens",
            "max_completion_tokens",
            "top_p",
            "stop",
            "n",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
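        """
        Map OpenAI-style parameters onto TGI parameter names, e.g.
        `max_tokens`/`max_completion_tokens` -> `max_new_tokens` and `n` -> `best_of`.
        A zero temperature is bumped to 0.01 unless `aws_sagemaker_allow_zero_temp`
        is set, since TGI rejects `temperature=0`.
        """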
        for param, value in non_default_params.items():
            if param == "temperature":
                if value == 0.0 or value == 0:
                    # hugging face exception raised when temp==0
                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
                    if not non_default_params.get(
                        "aws_sagemaker_allow_zero_temp", False
                    ):
                        value = 0.01
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "n":
                optional_params["best_of"] = value
                optional_params[
                    "do_sample"
                ] = True  # Need to sample if you want best of for hf inference endpoints
            if param == "stream":
                optional_params["stream"] = value
            if param == "stop":
                optional_params["stop"] = value
            if param == "max_tokens":
                # HF TGI raises the following exception when max_new_tokens==0
                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
                if value == 0:
                    value = 1
                optional_params["max_new_tokens"] = value
            if param == "max_completion_tokens":
                optional_params["max_new_tokens"] = value
        non_default_params.pop("aws_sagemaker_allow_zero_temp", None)
        return optional_params

    def _transform_prompt(
        self,
        model: str,
        messages: List,
        custom_prompt_dict: dict,
        hf_model_name: Optional[str],
    ) -> str:
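        """
        Render the chat messages into a single prompt string: prefer a custom
        prompt template registered for `model` or `hf_model_name`, otherwise fall
        back to `prompt_factory` with the (possibly inferred) Huggingface model name.
        """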
        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details.get("roles", None),
                initial_prompt_value=model_prompt_details.get(
                    "initial_prompt_value", ""
                ),
                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
                messages=messages,
            )
        elif hf_model_name in custom_prompt_dict:
            # check if the base huggingface model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[hf_model_name]
            prompt = custom_prompt(
                role_dict=model_prompt_details.get("roles", None),
                initial_prompt_value=model_prompt_details.get(
                    "initial_prompt_value", ""
                ),
                final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
                messages=messages,
            )
        else:
            if hf_model_name is None:
                if "llama-2" in model.lower():  # llama-2 model
                    if "chat" in model.lower():  # apply llama2 chat template
                        hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
                    else:  # apply regular llama2 template
                        hf_model_name = "meta-llama/Llama-2-7b"
            hf_model_name = (
                hf_model_name or model
            )  # pass in hf model name for pulling its prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf"` applies the llama2 chat template to the prompt)
            prompt: str = prompt_factory(model=hf_model_name, messages=messages)  # type: ignore
        return prompt

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
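        """
        Build the TGI-style request body, e.g.
        `{"inputs": "<rendered prompt>", "parameters": {...}, "stream": True}`
        (the `stream` key is only set for streaming requests).
        """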
        inference_params = optional_params.copy()
        stream = inference_params.pop("stream", False)
        data: Dict = {"parameters": inference_params}
        if stream is True:
            data["stream"] = True

        custom_prompt_dict = (
            litellm_params.get("custom_prompt_dict", None) or litellm.custom_prompt_dict
        )
        hf_model_name = litellm_params.get("hf_model_name", None)
        prompt = self._transform_prompt(
            model=model,
            messages=messages,
            custom_prompt_dict=custom_prompt_dict,
            hf_model_name=hf_model_name,
        )
        data["inputs"] = prompt
        return data

    async def async_transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
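        """Async wrapper that runs the synchronous `transform_request` via `asyncify`."""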
        return await asyncify(self.transform_request)(
            model, messages, optional_params, litellm_params, headers
        )

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: str,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
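        """
        Parse the raw Sagemaker/TGI JSON response into a `ModelResponse`: pull the
        generated text, strip the echoed prompt if present, and compute usage by
        counting tokens locally.
        """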
        completion_response = raw_response.json()

        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key="",
            original_response=completion_response,
            additional_args={"complete_input_dict": request_data},
        )

        prompt = request_data["inputs"]

        ## RESPONSE OBJECT
        try:
            if isinstance(completion_response, list):
                completion_response_choices = completion_response[0]
            else:
                completion_response_choices = completion_response
            completion_output = ""
            if "generation" in completion_response_choices:
                completion_output += completion_response_choices["generation"]
            elif "generated_text" in completion_response_choices:
                completion_output += completion_response_choices["generated_text"]

            # check if the prompt template is part of output, if so - filter it out
            if completion_output.startswith(prompt) and "<s>" in prompt:
                completion_output = completion_output.replace(prompt, "", 1)

            model_response.choices[0].message.content = completion_output  # type: ignore
        except Exception:
            raise SagemakerError(
                message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}",
                status_code=500,
            )
        ## CALCULATING USAGE - token counts are computed locally from the prompt and generated text
        prompt_tokens = token_counter(
            text=prompt, count_response_tokens=True
        )  # doesn't apply any default token count from openai's chat template
        completion_tokens = token_counter(
            text=model_response["choices"][0]["message"].get("content", ""),
            count_response_tokens=True,
        )
        model_response.created = int(time.time())
        model_response.model = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        setattr(model_response, "usage", usage)
        return model_response

    def validate_environment(
        self,
        headers: Optional[dict],
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
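        """Return request headers: `Content-Type: application/json`, merged with any caller-supplied headers."""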
        default_headers = {"Content-Type": "application/json"}
        if headers is not None:
            headers = {**default_headers, **headers}
        else:
            headers = default_headers
        return headers