Spaces:
Paused
Paused
| import json | |
| from typing import Any, Callable, Optional, Union | |
| import httpx | |
| import litellm | |
| from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj | |
| from litellm.llms.custom_httpx.http_handler import ( | |
| AsyncHTTPHandler, | |
| HTTPHandler, | |
| get_async_httpx_client, | |
| ) | |
| from litellm.types.llms.bedrock import CohereEmbeddingRequest | |
| from litellm.types.utils import EmbeddingResponse | |
| from .transformation import CohereEmbeddingConfig | |
| def validate_environment(api_key, headers: dict): | |
| headers.update( | |
| { | |
| "Request-Source": "unspecified:litellm", | |
| "accept": "application/json", | |
| "content-type": "application/json", | |
| } | |
| ) | |
| if api_key: | |
| headers["Authorization"] = f"Bearer {api_key}" | |
| return headers | |
| class CohereError(Exception): | |
| def __init__(self, status_code, message): | |
| self.status_code = status_code | |
| self.message = message | |
| self.request = httpx.Request( | |
| method="POST", url="https://api.cohere.ai/v1/generate" | |
| ) | |
| self.response = httpx.Response(status_code=status_code, request=self.request) | |
| super().__init__( | |
| self.message | |
| ) # Call the base class constructor with the parameters it needs | |
| async def async_embedding( | |
| model: str, | |
| data: Union[dict, CohereEmbeddingRequest], | |
| input: list, | |
| model_response: litellm.utils.EmbeddingResponse, | |
| timeout: Optional[Union[float, httpx.Timeout]], | |
| logging_obj: LiteLLMLoggingObj, | |
| optional_params: dict, | |
| api_base: str, | |
| api_key: Optional[str], | |
| headers: dict, | |
| encoding: Callable, | |
| client: Optional[AsyncHTTPHandler] = None, | |
| ): | |
| ## LOGGING | |
| logging_obj.pre_call( | |
| input=input, | |
| api_key=api_key, | |
| additional_args={ | |
| "complete_input_dict": data, | |
| "headers": headers, | |
| "api_base": api_base, | |
| }, | |
| ) | |
| ## COMPLETION CALL | |
| if client is None: | |
| client = get_async_httpx_client( | |
| llm_provider=litellm.LlmProviders.COHERE, | |
| params={"timeout": timeout}, | |
| ) | |
| try: | |
| response = await client.post(api_base, headers=headers, data=json.dumps(data)) | |
| except httpx.HTTPStatusError as e: | |
| ## LOGGING | |
| logging_obj.post_call( | |
| input=input, | |
| api_key=api_key, | |
| additional_args={"complete_input_dict": data}, | |
| original_response=e.response.text, | |
| ) | |
| raise e | |
| except Exception as e: | |
| ## LOGGING | |
| logging_obj.post_call( | |
| input=input, | |
| api_key=api_key, | |
| additional_args={"complete_input_dict": data}, | |
| original_response=str(e), | |
| ) | |
| raise e | |
| ## PROCESS RESPONSE ## | |
| return CohereEmbeddingConfig()._transform_response( | |
| response=response, | |
| api_key=api_key, | |
| logging_obj=logging_obj, | |
| data=data, | |
| model_response=model_response, | |
| model=model, | |
| encoding=encoding, | |
| input=input, | |
| ) | |
| def embedding( | |
| model: str, | |
| input: list, | |
| model_response: EmbeddingResponse, | |
| logging_obj: LiteLLMLoggingObj, | |
| optional_params: dict, | |
| headers: dict, | |
| encoding: Any, | |
| data: Optional[Union[dict, CohereEmbeddingRequest]] = None, | |
| complete_api_base: Optional[str] = None, | |
| api_key: Optional[str] = None, | |
| aembedding: Optional[bool] = None, | |
| timeout: Optional[Union[float, httpx.Timeout]] = httpx.Timeout(None), | |
| client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, | |
| ): | |
| headers = validate_environment(api_key, headers=headers) | |
| embed_url = complete_api_base or "https://api.cohere.ai/v1/embed" | |
| model = model | |
| data = data or CohereEmbeddingConfig()._transform_request( | |
| model=model, input=input, inference_params=optional_params | |
| ) | |
| ## ROUTING | |
| if aembedding is True: | |
| return async_embedding( | |
| model=model, | |
| data=data, | |
| input=input, | |
| model_response=model_response, | |
| timeout=timeout, | |
| logging_obj=logging_obj, | |
| optional_params=optional_params, | |
| api_base=embed_url, | |
| api_key=api_key, | |
| headers=headers, | |
| encoding=encoding, | |
| client=( | |
| client | |
| if client is not None and isinstance(client, AsyncHTTPHandler) | |
| else None | |
| ), | |
| ) | |
| ## LOGGING | |
| logging_obj.pre_call( | |
| input=input, | |
| api_key=api_key, | |
| additional_args={"complete_input_dict": data}, | |
| ) | |
| ## COMPLETION CALL | |
| if client is None or not isinstance(client, HTTPHandler): | |
| client = HTTPHandler(concurrent_limit=1) | |
| response = client.post(embed_url, headers=headers, data=json.dumps(data)) | |
| return CohereEmbeddingConfig()._transform_response( | |
| response=response, | |
| api_key=api_key, | |
| logging_obj=logging_obj, | |
| data=data, | |
| model_response=model_response, | |
| model=model, | |
| encoding=encoding, | |
| input=input, | |
| ) | |