Spaces:
Paused
Paused
| from typing import List, Optional, Union | |
| import httpx | |
| from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj | |
| from litellm.llms.base_llm.chat.transformation import BaseLLMException | |
| from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig | |
| from litellm.secret_managers.main import get_secret_str | |
| from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues | |
| from litellm.types.utils import EmbeddingResponse, Usage | |
| from ..common_utils import InfinityError | |
class InfinityEmbeddingConfig(BaseEmbeddingConfig):
    """
    Request/response transformation for embeddings served by Infinity.

    Reference: https://infinity.modal.michaelfeil.eu/docs
    """

    def __init__(self) -> None:
        pass

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Build the embeddings endpoint URL from ``api_base``.

        Trailing slashes are stripped, and ``/embeddings`` is appended unless
        the base already ends with it.

        Raises:
            ValueError: if ``api_base`` is None.
        """
        if api_base is None:
            raise ValueError("api_base is required for Infinity embeddings")

        # Strip trailing slashes first so "…/embeddings/" is recognised as
        # already pointing at the endpoint.
        base_url = api_base.rstrip("/")
        if base_url.endswith("/embeddings"):
            return base_url
        return f"{base_url}/embeddings"

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Return the HTTP headers for a request.

        When ``api_key`` is None, the ``INFINITY_API_KEY`` secret is looked up
        via ``get_secret_str``. Caller-supplied ``headers`` take precedence
        over every default, including ``Authorization``.
        """
        if api_key is None:
            api_key = get_secret_str("INFINITY_API_KEY")

        default_headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
        }
        # Only send a bearer token when a key is actually available;
        # otherwise the header would be the literal string "Bearer None".
        if api_key:
            default_headers["Authorization"] = f"Bearer {api_key}"

        # Later keys win, so anything in `headers` overrides the defaults.
        return {**default_headers, **headers}

    def get_supported_openai_params(self, model: str) -> list:
        """Return the OpenAI parameter names handled by `map_openai_params`."""
        return [
            "encoding_format",
            "modality",
            "dimensions",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Map OpenAI params to Infinity params

        ``encoding_format`` and ``modality`` are copied under the same name;
        ``dimensions`` is forwarded as ``output_dimension``. ``optional_params``
        is updated in place and returned.

        Reference: https://infinity.modal.michaelfeil.eu/docs
        """
        openai_to_infinity = {
            "encoding_format": "encoding_format",
            "modality": "modality",
            "dimensions": "output_dimension",
        }
        for openai_name, infinity_name in openai_to_infinity.items():
            if openai_name in non_default_params:
                optional_params[infinity_name] = non_default_params[openai_name]
        return optional_params

    def transform_embedding_request(
        self,
        model: str,
        input: AllEmbeddingInputValues,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """
        Build the JSON request body.

        ``optional_params`` is spread last, so its keys override ``input`` and
        ``model`` if they collide.
        """
        return {
            "input": input,
            "model": model,
            **optional_params,
        }

    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> EmbeddingResponse:
        """
        Copy the fields of the raw HTTP response onto ``model_response``.

        ``model``, ``data`` and ``object`` are taken from the response JSON, and
        a ``Usage`` is built from its ``usage`` entry; token counts default to 0
        when the entry or a count is missing.

        Raises:
            InfinityError: if the response body cannot be decoded as JSON.
        """
        # NOTE: the `{}` defaults are kept to leave the signature unchanged;
        # they are never read or mutated in this method.
        try:
            payload = raw_response.json()
        except Exception as exc:
            raise InfinityError(
                message=raw_response.text, status_code=raw_response.status_code
            ) from exc

        model_response.model = payload.get("model")
        model_response.data = payload.get("data")
        model_response.object = payload.get("object")

        # `or {}` also covers an explicit `"usage": null`, which a plain
        # `.get("usage", {})` would pass through as None.
        usage_info = payload.get("usage") or {}
        model_response.usage = Usage(
            prompt_tokens=usage_info.get("prompt_tokens", 0),
            total_tokens=usage_info.get("total_tokens", 0),
        )
        return model_response

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Wrap an error message, status code and headers in an `InfinityError`."""
        return InfinityError(
            message=error_message, status_code=status_code, headers=headers
        )