import json
import logging
import os
from abc import ABC
from typing import Callable, Dict, List

import openai
from tenacity import (
    before_sleep_log,
    retry,
    stop_after_attempt,
    wait_random_exponential,
)

from ..base_llm import BaseLLM
from ...schemas import *
|
|
| logger = logging.getLogger(__name__) |
|
|
| MAX_PROMPT_LENGTH = 7000 |
|
|
|
|
| @retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(10), reraise=True, |
| before_sleep=before_sleep_log(logger, logging.WARNING)) |
| def chatcompletion_with_backoff(**kwargs): |
| return openai.ChatCompletion.create(**kwargs) |
|
|
|
|
| @retry(wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(10), reraise=True, |
| before_sleep=before_sleep_log(logger, logging.WARNING)) |
| async def async_chatcompletion_with_backoff(**kwargs): |
| async def _internal_coroutine(): |
| return await openai.ChatCompletion.acreate(**kwargs) |
|
|
| return await _internal_coroutine() |
|
|
|
|
| class OptOpenAIClient(BaseLLM, ABC): |
| """ |
| Wrapper class for OpenAI GPT API collections. |
| |
| :param model_name: The name of the model to use. |
| :type model_name: str |
| :param params: The parameters for the model. |
| :type params: OptParamModel |
| """ |
|
|
| model_name: str |
| params: OptParamModel = OptParamModel() |
|
|
| def __init__(self, **data): |
| super().__init__(**data) |
| openai.api_key = "EMPTY" |
| openai.api_base = "http://localhost:8000/v1" |
|
|
| @classmethod |
| async def create(cls, config_data): |
| return OptOpenAIClient(**config_data) |
|
|
| def get_model_name(self) -> str: |
| return self.model_name |
| |
| def get_model_param(self) -> OptParamModel: |
| return self.params |
|
|
| def completion(self, prompt: str, **kwargs) -> BaseCompletion: |
| """ |
| Completion method for OpenAI GPT API. |
| |
| :param prompt: The prompt to use for completion. |
| :type prompt: str |
| :param kwargs: Additional keyword arguments. |
| :type kwargs: dict |
| :return: BaseCompletion object. |
| :rtype: BaseCompletion |
| |
| """ |
|
|
| response = chatcompletion_with_backoff( |
| model=self.model_name, |
| |
| messages=[ |
| {"role": "user", "content": prompt[-MAX_PROMPT_LENGTH:]} |
| ], |
| timeout=1000, |
| **kwargs |
| ) |
|
|
| return BaseCompletion(state="success", |
| content=response.choices[0].message["content"], |
| prompt_token=response.get("usage", {}).get("prompt_tokens", 0), |
| completion_token=response.get("usage", {}).get("completion_tokens", 0)) |
|
|
| async def async_completion(self, prompt: str, **kwargs) -> BaseCompletion: |
| """ |
| Completion method for OpenAI GPT API. |
| |
| :param prompt: The prompt to use for completion. |
| :type prompt: str |
| :param kwargs: Additional keyword arguments. |
| :type kwargs: dict |
| :return: BaseCompletion object. |
| :rtype: BaseCompletion |
| |
| """ |
| response = await async_chatcompletion_with_backoff( |
| |
| model=self.model_name, |
| messages=[ |
| {"role": "user", "content": prompt[-MAX_PROMPT_LENGTH:]} |
| ], |
| timeout=1000, |
| **kwargs |
| ) |
|
|
| return BaseCompletion(state="success", |
| content=response.choices[0].message["content"], |
| prompt_token=response.get("usage", {}).get("prompt_tokens", 0), |
| completion_token=response.get("usage", {}).get("completion_tokens", 0)) |
|
|
| def chat_completion(self, message: List[dict]) -> ChatCompletion: |
| """ |
| Chat completion method for OpenAI GPT API. |
| |
| :param message: The message to use for completion. |
| :type message: List[dict] |
| :return: ChatCompletion object. |
| :rtype: ChatCompletion |
| """ |
| try: |
| |
| |
| |
| |
| |
| response = openai.ChatCompletion.create( |
| n=self.params.n, |
| model=self.model_name, |
| messages=message, |
| temperature=self.params.temperature, |
| max_tokens=self.params.max_tokens, |
| top_p=self.params.top_p, |
| frequency_penalty=self.params.frequency_penalty, |
| presence_penalty=self.params.presence_penalty, |
| ) |
| return ChatCompletion( |
| state="success", |
| role=response.choices[0].message["role"], |
| content=response.choices[0].message["content"], |
| prompt_token=response.get("usage", {}).get("prompt_tokens", 0), |
| completion_token=response.get("usage", {}).get("completion_tokens", 0), |
| ) |
| except Exception as exception: |
| print("Exception:", exception) |
| return ChatCompletion(state="error", content=exception) |
|
|
| def stream_chat_completion(self, message: List[dict], **kwargs): |
| """ |
| Stream output chat completion for OpenAI GPT API. |
| |
| :param message: The message (scratchpad) to use for completion. Usually contains json of role and content. |
| :type message: List[dict] |
| :param kwargs: Additional keyword arguments. |
| :type kwargs: dict |
| :return: ChatCompletion object. |
| :rtype: ChatCompletion |
| """ |
| try: |
| |
| |
| |
| |
| |
| |
| response = openai.ChatCompletion.create( |
| n=self.params.n, |
| model=self.model_name, |
| messages=message, |
| temperature=self.params.temperature, |
| max_tokens=self.params.max_tokens, |
| top_p=self.params.top_p, |
| frequency_penalty=self.params.frequency_penalty, |
| presence_penalty=self.params.presence_penalty, |
| stream=True, |
| **kwargs |
| ) |
| role = next(response).choices[0].delta["role"] |
| messages = [] |
| |
| for resp in response: |
| messages.append(resp.choices[0].delta.get("content", "")) |
| yield ChatCompletion( |
| state="success", |
| role=role, |
| content=messages[-1], |
| prompt_token=0, |
| completion_token=0, |
| ) |
| except Exception as exception: |
| print("Exception:", exception) |
| return ChatCompletion(state="error", content=exception) |
|
|
| def function_chat_completion( |
| self, |
| message: List[dict], |
| function_map: Dict[str, Callable], |
| function_schema: List[Dict], |
| ) -> ChatCompletionWithHistory: |
| """ |
| Chat completion method for OpenAI GPT API. |
| |
| :param message: The message to use for completion. |
| :type message: List[dict] |
| :param function_map: The function map to use for completion. |
| :type function_map: Dict[str, Callable] |
| :param function_schema: The function schema to use for completion. |
| :type function_schema: List[Dict] |
| :return: ChatCompletionWithHistory object. |
| :rtype: ChatCompletionWithHistory |
| """ |
| assert len(function_schema) == len(function_map) |
| try: |
| |
| |
| |
| |
| |
| |
| response = openai.ChatCompletion.create( |
| n=self.params.n, |
| model=self.model_name, |
| messages=message, |
| functions=function_schema, |
| temperature=self.params.temperature, |
| max_tokens=self.params.max_tokens, |
| top_p=self.params.top_p, |
| frequency_penalty=self.params.frequency_penalty, |
| presence_penalty=self.params.presence_penalty, |
| ) |
| response_message = response.choices[0]["message"] |
|
|
| if response_message.get("function_call"): |
| function_name = response_message["function_call"]["name"] |
| fuction_to_call = function_map[function_name] |
| function_args = json.loads( |
| response_message["function_call"]["arguments"] |
| ) |
| function_response = fuction_to_call(**function_args) |
|
|
| |
| if isinstance(function_response, str): |
| plugin_cost = 0 |
| plugin_token = 0 |
| elif isinstance(function_response, AgentOutput): |
| plugin_cost = function_response.cost |
| plugin_token = function_response.token_usage |
| function_response = function_response.output |
| else: |
| raise Exception( |
| "Invalid tool response type. Must be on of [AgentOutput, str]" |
| ) |
|
|
| message.append(dict(response_message)) |
| message.append( |
| { |
| "role": "function", |
| "name": function_name, |
| "content": function_response, |
| } |
| ) |
| second_response = openai.ChatCompletion.create( |
| model=self.get_model_name(), |
| messages=message, |
| ) |
| message.append(dict(second_response.choices[0].message)) |
| return ChatCompletionWithHistory( |
| state="success", |
| role=second_response.choices[0].message["role"], |
| content=second_response.choices[0].message["content"], |
| prompt_token=response.get("usage", {}).get("prompt_tokens", 0) |
| + second_response.get("usage", {}).get("prompt_tokens", 0), |
| completion_token=response.get("usage", {}).get( |
| "completion_tokens", 0 |
| ) |
| + second_response.get("usage", {}).get("completion_tokens", 0), |
| message_scratchpad=message, |
| plugin_cost=plugin_cost, |
| plugin_token=plugin_token, |
| ) |
| else: |
| message.append(dict(response_message)) |
| return ChatCompletionWithHistory( |
| state="success", |
| role=response.choices[0].message["role"], |
| content=response.choices[0].message["content"], |
| prompt_token=response.get("usage", {}).get("prompt_tokens", 0), |
| completion_token=response.get("usage", {}).get( |
| "completion_tokens", 0 |
| ), |
| message_scratchpad=message, |
| ) |
|
|
| except Exception as exception: |
| print("Exception:", exception) |
| return ChatCompletionWithHistory(state="error", content=str(exception)) |
|
|
| def function_chat_stream_completion( |
| self, |
| message: List[dict], |
| function_map: Dict[str, Callable], |
| function_schema: List[Dict], |
| ) -> ChatCompletionWithHistory: |
| assert len(function_schema) == len(function_map) |
| try: |
| response = openai.ChatCompletion.create( |
| n=self.params.n, |
| model=self.get_model_name(), |
| messages=message, |
| functions=function_schema, |
| temperature=self.params.temperature, |
| max_tokens=self.params.max_tokens, |
| top_p=self.params.top_p, |
| frequency_penalty=self.params.frequency_penalty, |
| presence_penalty=self.params.presence_penalty, |
| stream=True, |
| ) |
| tmp = next(response) |
| role = tmp.choices[0].delta["role"] |
| _type = ( |
| "function_call" |
| if tmp.choices[0].delta["content"] is None |
| else "content" |
| ) |
| if _type == "function_call": |
| name = tmp.choices[0].delta["function_call"]["name"] |
| yield _type, ChatCompletionWithHistory( |
| state="success", |
| role=role, |
| content="{" + f'"name":"{name}", "arguments":', |
| message_scratchpad=message, |
| ) |
| for resp in response: |
| |
| content = resp.choices[0].delta.get(_type, "") |
| if isinstance(content, dict): |
| content = content["arguments"] |
| yield _type, ChatCompletionWithHistory( |
| state="success", |
| role=role, |
| content=content, |
| message_scratchpad=message, |
| ) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| except Exception as e: |
| logger.error(f"Failed to get response {str(e)}", exc_info=True) |
| raise e |
|
|