# What is this?
## handler file for TextCompletionCodestral Integration - https://codestral.com/

import json
from functools import partial
from typing import Callable, List, Optional, Union

import httpx  # type: ignore

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.litellm_core_utils.prompt_templates.factory import (
    custom_prompt,
    prompt_factory,
)
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    get_async_httpx_client,
)
from litellm.types.utils import TextChoices
from litellm.utils import CustomStreamWrapper, TextCompletionResponse


class TextCompletionCodestralError(Exception):
    def __init__(
        self,
        status_code,
        message,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
    ):
        self.status_code = status_code
        self.message = message
        if request is not None:
            self.request = request
        else:
            self.request = httpx.Request(
                method="POST",
                url="https://docs.codestral.com/user-guide/inference/rest_api",
            )
        if response is not None:
            self.response = response
        else:
            self.response = httpx.Response(
                status_code=status_code, request=self.request
            )
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


async def make_call(
    client: AsyncHTTPHandler,
    api_base: str,
    headers: dict,
    data: str,
    model: str,
    messages: list,
    logging_obj,
):
    response = await client.post(api_base, headers=headers, data=data, stream=True)

    if response.status_code != 200:
        raise TextCompletionCodestralError(
            status_code=response.status_code, message=response.text
        )

    completion_stream = response.aiter_lines()

    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response=completion_stream,  # Pass the completion stream for logging
        additional_args={"complete_input_dict": data},
    )

    return completion_stream
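

# Note: make_call returns the raw line iterator from the streaming response;
# downstream, CustomStreamWrapper is expected to parse each yielded line
# (e.g. SSE-style "data: {...}" chunks) into completion deltas.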


class CodestralTextCompletion:
    def __init__(self) -> None:
        super().__init__()

    def _validate_environment(
        self,
        api_key: Optional[str],
        user_headers: dict,
    ) -> dict:
        if api_key is None:
            raise ValueError(
                "Missing CODESTRAL_API_KEY - please add CODESTRAL_API_KEY to your environment variables"
            )
        headers = {
            "content-type": "application/json",
            "Authorization": "Bearer {}".format(api_key),
        }
        if user_headers is not None and isinstance(user_headers, dict):
            headers = {**headers, **user_headers}
        return headers
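
    # _validate_environment returns headers like (illustrative value):
    #   {"content-type": "application/json", "Authorization": "Bearer <api_key>"}
    # with any user-supplied headers merged in (user values win on conflict).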

    def output_parser(self, generated_text: str):
        """
        Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.

        Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
        """
        chat_template_tokens = [
            "<|assistant|>",
            "<|system|>",
            "<|user|>",
            "<s>",
            "</s>",
        ]
        for token in chat_template_tokens:
            if generated_text.strip().startswith(token):
                generated_text = generated_text.replace(token, "", 1)
            if generated_text.endswith(token):
                generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
        return generated_text
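
    # Example of output_parser behavior (illustrative input):
    #   output_parser("<|assistant|> hello </s>")  ->  " hello "
    # A leading or trailing ChatML token is stripped once from each end.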

    def process_text_completion_response(
        self,
        model: str,
        response: httpx.Response,
        model_response: TextCompletionResponse,
        stream: bool,
        logging_obj: LiteLLMLogging,
        optional_params: dict,
        api_key: str,
        data: Union[dict, str],
        messages: list,
        print_verbose,
        encoding,
    ) -> TextCompletionResponse:
        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=response.text,
            additional_args={"complete_input_dict": data},
        )
        print_verbose(f"codestral api: raw model_response: {response.text}")

        ## RESPONSE OBJECT
        if response.status_code != 200:
            raise TextCompletionCodestralError(
                message=str(response.text),
                status_code=response.status_code,
            )
        try:
            completion_response = response.json()
        except Exception:
            raise TextCompletionCodestralError(message=response.text, status_code=422)

        _original_choices = completion_response.get("choices", [])
        _choices: List[TextChoices] = []
        for choice in _original_choices:
            # This is what one choice looks like from the codestral API:
            # {
            #     "index": 0,
            #     "message": {
            #         "role": "assistant",
            #         "content": "\n assert is_odd(1)\n assert",
            #         "tool_calls": null
            #     },
            #     "finish_reason": "length",
            #     "logprobs": null
            # }
            _choice_message = choice.get("message", {})
            _choice = TextChoices(
                finish_reason=choice.get("finish_reason"),
                index=choice.get("index"),
                text=_choice_message.get("content"),
                logprobs=choice.get("logprobs"),
            )
            _choices.append(_choice)

        _response = litellm.TextCompletionResponse(
            id=completion_response.get("id"),
            choices=_choices,
            created=completion_response.get("created"),
            model=completion_response.get("model"),
            usage=completion_response.get("usage"),
            stream=False,
            object=completion_response.get("object"),
        )
        return _response

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: TextCompletionResponse,
        print_verbose: Callable,
        encoding,
        api_key: str,
        logging_obj,
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers: dict = {},
    ) -> Union[TextCompletionResponse, CustomStreamWrapper]:
        headers = self._validate_environment(api_key, headers)

        if optional_params.pop("custom_endpoint", None) is True:
            completion_url = api_base
        else:
            completion_url = (
                api_base or "https://codestral.mistral.ai/v1/fim/completions"
            )

        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details["roles"],
                initial_prompt_value=model_prompt_details["initial_prompt_value"],
                final_prompt_value=model_prompt_details["final_prompt_value"],
                messages=messages,
            )
        else:
            prompt = prompt_factory(model=model, messages=messages)

        ## Load Config
        config = litellm.CodestralTextCompletionConfig.get_config()
        for k, v in config.items():
            if (
                k not in optional_params
            ):  # completion(top_k=3) > codestral_config(top_k=3) <- params passed at call time override the static config
                optional_params[k] = v

        stream = optional_params.pop("stream", False)

        data = {
            "model": model,
            "prompt": prompt,
            **optional_params,
        }
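
        # Illustrative request body for the FIM endpoint (example values only,
        # not taken from the Codestral docs):
        #   {"model": "codestral-latest", "prompt": "def add(a, b):", "max_tokens": 64}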
        input_text = prompt

        ## LOGGING
        logging_obj.pre_call(
            input=input_text,
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "headers": headers,
                "api_base": completion_url,
                "acompletion": acompletion,
            },
        )

        ## COMPLETION CALL
        if acompletion is True:
            ### ASYNC STREAMING
            if stream is True:
                return self.async_streaming(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=completion_url,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                )  # type: ignore
            else:
                ### ASYNC COMPLETION
                return self.async_completion(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=completion_url,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=False,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                )  # type: ignore

        ### SYNC STREAMING
        if stream is True:
            response = litellm.module_level_client.post(
                completion_url,
                headers=headers,
                data=json.dumps(data),
                stream=stream,
            )
            _response = CustomStreamWrapper(
                response.iter_lines(),
                model,
| custom_llm_provider="codestral", | |
                logging_obj=logging_obj,
            )
            return _response
        ### SYNC COMPLETION
        else:
            response = litellm.module_level_client.post(
                url=completion_url,
                headers=headers,
                data=json.dumps(data),
            )
        return self.process_text_completion_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=stream,  # "stream" was popped from optional_params above; always False on this sync path
            logging_obj=logging_obj,  # type: ignore
            optional_params=optional_params,
            api_key=api_key,
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            encoding=encoding,
        )
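
    # Routing summary for completion():
    #   acompletion=True,  stream=True  -> async_streaming (returns CustomStreamWrapper)
    #   acompletion=True,  stream=False -> async_completion (returns TextCompletionResponse)
    #   acompletion=False, stream=True  -> sync streaming via litellm.module_level_client
    #   acompletion=False, stream=False -> sync call + process_text_completion_response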

    async def async_completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: TextCompletionResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        stream,
        data: dict,
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
        litellm_params=None,
        logger_fn=None,
        headers={},
    ) -> TextCompletionResponse:
        async_handler = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.TEXT_COMPLETION_CODESTRAL,
            params={"timeout": timeout},
        )
        try:
            response = await async_handler.post(
                api_base, headers=headers, data=json.dumps(data)
            )
        except httpx.HTTPStatusError as e:
            raise TextCompletionCodestralError(
                status_code=e.response.status_code,
                message="HTTPStatusError - {}".format(e.response.text),
            )
        except Exception as e:
            raise TextCompletionCodestralError(
                status_code=500, message="{}".format(str(e))
            )  # avoid verbose_logger.exception here; the error is surfaced by re-raising
        return self.process_text_completion_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=stream,
            logging_obj=logging_obj,
            api_key=api_key,
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            optional_params=optional_params,
            encoding=encoding,
        )

    async def async_streaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: TextCompletionResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        data: dict,
        timeout: Union[float, httpx.Timeout],
        optional_params=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
    ) -> CustomStreamWrapper:
        data["stream"] = True

        streamwrapper = CustomStreamWrapper(
            completion_stream=None,
            make_call=partial(
                make_call,
                api_base=api_base,
                headers=headers,
                data=json.dumps(data),
                model=model,
                messages=messages,
                logging_obj=logging_obj,
            ),
            model=model,
            custom_llm_provider="text-completion-codestral",
            logging_obj=logging_obj,
        )
        return streamwrapper

    def embedding(self, *args, **kwargs):
        pass
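

# Example usage (a minimal sketch): in practice this handler is reached through
# litellm's public text completion API rather than instantiated directly. The
# model name and prompt below are illustrative.
#
#   import litellm
#   resp = litellm.text_completion(
#       model="text-completion-codestral/codestral-latest",
#       prompt="def fib(n):",
#   )
#   print(resp.choices[0].text)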