# Cohere chat-completion transformation helpers (LiteLLM).
| import json | |
| from typing import List, Optional | |
| from litellm.llms.base_llm.chat.transformation import BaseLLMException | |
| from litellm.types.llms.openai import AllMessageValues | |
| from litellm.types.utils import ( | |
| ChatCompletionToolCallChunk, | |
| ChatCompletionUsageBlock, | |
| GenericStreamingChunk, | |
| ) | |
class CohereError(BaseLLMException):
    """Error raised for a failed Cohere API call.

    Carries the HTTP ``status_code`` and the provider's error ``message``
    through the shared LiteLLM exception base.
    """

    def __init__(self, status_code, message):
        # Delegate storage/formatting to the common LiteLLM exception base.
        super().__init__(status_code=status_code, message=message)
def validate_environment(
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    api_key: Optional[str] = None,
) -> dict:
    """
    Return headers to use for cohere chat completion request

    Cohere API Ref: https://docs.cohere.com/reference/chat
    Expected headers:
    {
        "Request-Source": "unspecified:litellm",
        "accept": "application/json",
        "content-type": "application/json",
        "Authorization": "bearer $CO_API_KEY"
    }

    Mutates *headers* in place and returns the same dict; the Authorization
    header is only attached when an api_key is provided.
    """
    default_headers = {
        "Request-Source": "unspecified:litellm",
        "accept": "application/json",
        "content-type": "application/json",
    }
    headers.update(default_headers)
    if api_key:
        headers["Authorization"] = f"bearer {api_key}"
    return headers
class ModelResponseIterator:
    """Sync/async iterator over a Cohere chat streaming response.

    Each raw chunk (bytes or str, optionally SSE ``data:``-framed) is JSON
    decoded and mapped onto a ``GenericStreamingChunk``.
    """

    def __init__(
        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
    ):
        self.streaming_response = streaming_response
        self.response_iterator = self.streaming_response
        self.content_blocks: List = []
        # Running tool-call index; -1 means no tool call seen yet.
        self.tool_index = -1
        self.json_mode = json_mode

    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        """Map one decoded Cohere stream event onto a GenericStreamingChunk.

        A chunk carries either incremental ``text`` or a terminal
        ``is_finished``/``finish_reason`` pair; ``citations``, when present,
        are surfaced via ``provider_specific_fields``.
        """
        try:
            text = ""
            tool_use: Optional[ChatCompletionToolCallChunk] = None
            is_finished = False
            finish_reason = ""
            usage: Optional[ChatCompletionUsageBlock] = None
            provider_specific_fields = None

            index = int(chunk.get("index", 0))

            if "text" in chunk:
                text = chunk["text"]
            elif "is_finished" in chunk and chunk["is_finished"] is True:
                is_finished = chunk["is_finished"]
                finish_reason = chunk["finish_reason"]

            if "citations" in chunk:
                provider_specific_fields = {"citations": chunk["citations"]}

            returned_chunk = GenericStreamingChunk(
                text=text,
                tool_use=tool_use,
                is_finished=is_finished,
                finish_reason=finish_reason,
                usage=usage,
                index=index,
                provider_specific_fields=provider_specific_fields,
            )

            return returned_chunk

        except json.JSONDecodeError:
            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")

    # Sync iterator
    def __iter__(self):
        return self

    def __next__(self):
        try:
            chunk = self.response_iterator.__next__()
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")

        try:
            return self.convert_str_chunk_to_generic_chunk(chunk=chunk)
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")

    def convert_str_chunk_to_generic_chunk(self, chunk: str) -> GenericStreamingChunk:
        """
        Convert a string chunk to a GenericStreamingChunk

        Note: This is used for Cohere pass through streaming logging

        Raises ValueError (json.JSONDecodeError) if the payload is not
        valid JSON; __next__/__anext__ re-raise that as RuntimeError.
        """
        str_line = chunk
        if isinstance(chunk, bytes):  # Handle binary data
            str_line = chunk.decode("utf-8")  # Convert bytes to string

        index = str_line.find("data:")
        if index != -1:
            # BUGFIX: skip past the SSE "data:" prefix itself. The previous
            # slice (str_line[index:]) kept the prefix, so json.loads() failed
            # on every SSE-framed line — the exact lines this branch exists for.
            str_line = str_line[index + len("data:") :]

        data_json = json.loads(str_line)
        return self.chunk_parser(chunk=data_json)

    # Async iterator
    def __aiter__(self):
        self.async_response_iterator = self.streaming_response.__aiter__()
        return self

    async def __anext__(self):
        try:
            chunk = await self.async_response_iterator.__anext__()
        except StopAsyncIteration:
            raise StopAsyncIteration
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")

        try:
            return self.convert_str_chunk_to_generic_chunk(chunk=chunk)
        except StopAsyncIteration:
            raise StopAsyncIteration
        except ValueError as e:
            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")