Spaces:
Paused
Paused
| from typing import List, Optional, Union | |
| import httpx | |
| from litellm.llms.base_llm.chat.transformation import AllMessageValues, BaseLLMException | |
| from litellm.llms.base_llm.embedding.transformation import ( | |
| BaseEmbeddingConfig, | |
| LiteLLMLoggingObj, | |
| ) | |
| from litellm.types.llms.openai import AllEmbeddingInputValues | |
| from litellm.types.utils import EmbeddingResponse | |
| from ..common_utils import TritonError | |
class TritonEmbeddingConfig(BaseEmbeddingConfig):
    """
    Transformations for triton /embeddings endpoint (This is a trtllm model)

    Builds the JSON request body sent to the endpoint and converts the
    ``outputs`` section of the JSON response into embedding entries on an
    ``EmbeddingResponse``.
    """

    def __init__(self) -> None:
        pass

    def get_supported_openai_params(self, model: str) -> list:
        """Return the supported OpenAI params for ``model``: always empty."""
        return []

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Map OpenAI params to Triton Embedding params

        No mapping is performed: ``optional_params`` is returned unchanged and
        ``non_default_params`` is ignored.
        """
        return optional_params

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Return an empty dict; none of the arguments are inspected."""
        return {}

    def transform_embedding_request(
        self,
        model: str,
        input: AllEmbeddingInputValues,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """
        Build the request body for an embedding call.

        The body holds a single input tensor named ``input_text`` with
        datatype ``BYTES``, whose shape is the number of input items and whose
        data is the list of items.

        Args:
            model: Unused here.
            input: One string, or a list of inputs.
            optional_params: Unused here.
            headers: Unused here.

        Returns:
            The request body as a dict with a single ``"inputs"`` entry.
        """
        # A bare string is ONE input item. Without wrapping it, len() would
        # report its character count as the batch size and "data" would not
        # be a list.
        batch = [input] if isinstance(input, str) else input
        return {
            "inputs": [
                {
                    "name": "input_text",
                    "shape": [len(batch)],
                    "datatype": "BYTES",
                    "data": batch,
                }
            ]
        }

    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        # NOTE: the dict defaults below are kept for signature compatibility;
        # this method never reads or mutates them.
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> EmbeddingResponse:
        """
        Populate ``model_response`` from the raw HTTP response.

        Each entry of the response's ``outputs`` list is split into rows via
        ``split_embedding_by_shape`` using its ``shape`` and ``data`` fields,
        and every row is appended as an ``{"object": "embedding", ...}`` item.
        The ``index`` of each item is the row position within its own output
        entry (it restarts at 0 for every entry of ``outputs``).

        Returns:
            The same ``model_response`` object, with ``model`` and ``data``
            set.

        Raises:
            TritonError: If the body is not valid JSON, or is JSON without an
                ``outputs`` field.
            ValueError: If an output's ``shape`` does not have length 2.
        """
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise TritonError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        # A well-formed JSON body that lacks "outputs" is reported the same
        # way as an undecodable body instead of leaking a KeyError/TypeError.
        if not isinstance(raw_response_json, dict) or "outputs" not in raw_response_json:
            raise TritonError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        embedding_items = []
        for output in raw_response_json["outputs"]:
            rows = self.split_embedding_by_shape(output["data"], output["shape"])
            for idx, embedding in enumerate(rows):
                embedding_items.append(
                    {
                        "object": "embedding",
                        "index": idx,
                        "embedding": embedding,
                    }
                )

        # The fallback is the literal string "None", not the None singleton.
        model_response.model = raw_response_json.get("model_name", "None")
        model_response.data = embedding_items
        return model_response

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Wrap the given message, status code and headers in a ``TritonError``."""
        return TritonError(
            message=error_message, status_code=status_code, headers=headers
        )

    # Declared static because the signature takes no ``self`` while
    # ``transform_embedding_response`` invokes it through the instance.
    @staticmethod
    def split_embedding_by_shape(
        data: List[float], shape: List[int]
    ) -> List[List[float]]:
        """
        Split a flat list of values into ``shape[0]`` consecutive rows of
        ``shape[1]`` values each.

        Args:
            data: Flat list of values.
            shape: Two-element list ``[row_count, row_length]``.

        Returns:
            A list of ``shape[0]`` slices of ``data``. Slices taken past the
            end of ``data`` are shorter or empty; no length check is made.

        Raises:
            ValueError: If ``shape`` does not have exactly two elements.
        """
        if len(shape) != 2:
            raise ValueError("Shape must be of length 2.")
        embedding_size = shape[1]
        return [
            data[i * embedding_size : (i + 1) * embedding_size] for i in range(shape[0])
        ]