g3eIL's picture
Upload 80 files
77320e4 verified
import json
import logging
import os
from abc import ABC
from typing import Callable, Dict, Iterator, List, Tuple

import openai
from tenacity import (  # for exponential backoff
    before_sleep_log,
    retry,
    stop_after_attempt,
    wait_random_exponential,
)

from ..base_llm import BaseLLM
from ...schemas import *
# Logger for this module.
logger = logging.getLogger(name=__name__)

# Maximum number of prompt characters sent by the single-turn completion
# helpers; longer prompts are cut down to their last MAX_PROMPT_LENGTH chars.
MAX_PROMPT_LENGTH: int = 7_000
@retry(
    wait=wait_random_exponential(min=1, max=10),
    stop=stop_after_attempt(10),
    reraise=True,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def chatcompletion_with_backoff(**kwargs):
    """
    Call ``openai.ChatCompletion.create`` with randomized exponential back-off.

    No ``retry=`` predicate is given, so tenacity's default policy decides
    which failures are retried. Waits are bounded between 1 and 10 seconds,
    each wait is logged at WARNING level, and after 10 attempts the last
    exception is re-raised unchanged.

    :param kwargs: Forwarded unchanged to ``openai.ChatCompletion.create``.
    :return: The API response object.
    """
    create_chat = openai.ChatCompletion.create
    return create_chat(**kwargs)
@retry(
    wait=wait_random_exponential(min=1, max=10),
    stop=stop_after_attempt(10),
    reraise=True,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
async def async_chatcompletion_with_backoff(**kwargs):
    """
    Async counterpart of ``chatcompletion_with_backoff``.

    Awaits ``openai.ChatCompletion.acreate``. tenacity detects that the
    decorated function is a coroutine function and awaits each attempt, so no
    intermediate wrapper coroutine is needed. The retry, wait, stop and
    logging policy is the same as the synchronous helper's.

    :param kwargs: Forwarded unchanged to ``openai.ChatCompletion.acreate``.
    :return: The API response object.
    """
    acreate_chat = openai.ChatCompletion.acreate
    return await acreate_chat(**kwargs)
class OptOpenAIClient(BaseLLM, ABC):
    """
    Wrapper class for an OpenAI-compatible chat completion API.

    All requests go through the legacy module-level ``openai.ChatCompletion``
    interface. The constructor points that module at
    ``http://localhost:8000/v1`` with the placeholder API key ``"EMPTY"``.

    :param model_name: The name of the model to use.
    :type model_name: str
    :param params: The sampling parameters for the model.
    :type params: OptParamModel
    """

    model_name: str
    params: OptParamModel = OptParamModel()

    def __init__(self, **data):
        super().__init__(**data)
        # NOTE: these are module-global settings of the ``openai`` package, so
        # they also affect any other code in this process that uses it.
        openai.api_key = "EMPTY"
        openai.api_base = "http://localhost:8000/v1"

    @classmethod
    async def create(cls, config_data):
        """
        Build a client from a configuration mapping.

        :param config_data: Keyword arguments forwarded to the constructor.
        :type config_data: dict
        :return: A new instance of ``cls``, so subclasses construct themselves.
        """
        return cls(**config_data)

    def get_model_name(self) -> str:
        """Return the model name sent with every request."""
        return self.model_name

    def get_model_param(self) -> OptParamModel:
        """Return the sampling parameters used for chat requests."""
        return self.params

    # ----------------------------------------------------------------- helpers

    def _sampling_kwargs(self) -> dict:
        """Return the configured sampling parameters as request keyword arguments."""
        return {
            "n": self.params.n,
            "temperature": self.params.temperature,
            "max_tokens": self.params.max_tokens,
            "top_p": self.params.top_p,
            "frequency_penalty": self.params.frequency_penalty,
            "presence_penalty": self.params.presence_penalty,
        }

    @staticmethod
    def _prompt_messages(prompt: str) -> List[dict]:
        """Wrap ``prompt`` as one user message, keeping only its last MAX_PROMPT_LENGTH characters."""
        return [{"role": "user", "content": prompt[-MAX_PROMPT_LENGTH:]}]

    @staticmethod
    def _token_usage(response) -> Tuple[int, int]:
        """Return ``(prompt_tokens, completion_tokens)`` from a response; missing counts become 0."""
        # ``or {}`` also covers a ``usage`` key that is present but null.
        usage = response.get("usage") or {}
        return usage.get("prompt_tokens", 0), usage.get("completion_tokens", 0)

    @classmethod
    def _to_base_completion(cls, response) -> BaseCompletion:
        """Convert a chat response into a successful ``BaseCompletion``."""
        prompt_token, completion_token = cls._token_usage(response)
        return BaseCompletion(
            state="success",
            content=response.choices[0].message["content"],
            prompt_token=prompt_token,
            completion_token=completion_token,
        )

    # ------------------------------------------------------------- completions

    def completion(self, prompt: str, **kwargs) -> BaseCompletion:
        """
        Single-turn completion: send ``prompt`` as one user message.

        Only the last ``MAX_PROMPT_LENGTH`` characters of ``prompt`` are sent.
        The request goes through ``chatcompletion_with_backoff`` and is retried
        on failure. Errors are not caught here.

        :param prompt: The prompt to use for completion.
        :type prompt: str
        :param kwargs: Additional keyword arguments forwarded to the API call.
            ``timeout`` defaults to 1000 and may be overridden.
        :type kwargs: dict
        :return: BaseCompletion object.
        :rtype: BaseCompletion
        """
        response = chatcompletion_with_backoff(
            model=self.model_name,
            messages=self._prompt_messages(prompt),
            # The caller's kwargs win over the default timeout instead of
            # raising "multiple values for keyword argument 'timeout'".
            **{"timeout": 1000, **kwargs},
        )
        return self._to_base_completion(response)

    async def async_completion(self, prompt: str, **kwargs) -> BaseCompletion:
        """
        Async single-turn completion: send ``prompt`` as one user message.

        Only the last ``MAX_PROMPT_LENGTH`` characters of ``prompt`` are sent.
        The request goes through ``async_chatcompletion_with_backoff`` and is
        retried on failure. Errors are not caught here.

        :param prompt: The prompt to use for completion.
        :type prompt: str
        :param kwargs: Additional keyword arguments forwarded to the API call.
            ``timeout`` defaults to 1000 and may be overridden.
        :type kwargs: dict
        :return: BaseCompletion object.
        :rtype: BaseCompletion
        """
        response = await async_chatcompletion_with_backoff(
            model=self.model_name,
            messages=self._prompt_messages(prompt),
            **{"timeout": 1000, **kwargs},
        )
        return self._to_base_completion(response)

    def chat_completion(self, message: List[dict]) -> ChatCompletion:
        """
        Multi-turn chat completion using the configured sampling parameters.

        :param message: The chat history, a list of role/content dicts.
        :type message: List[dict]
        :return: ChatCompletion with ``state="success"``. On any exception the
            error is logged and a ChatCompletion with ``state="error"`` and the
            exception text as ``content`` is returned.
        :rtype: ChatCompletion
        """
        try:
            response = openai.ChatCompletion.create(
                model=self.model_name,
                messages=message,
                **self._sampling_kwargs(),
            )
            reply = response.choices[0].message
            prompt_token, completion_token = self._token_usage(response)
            return ChatCompletion(
                state="success",
                role=reply["role"],
                content=reply["content"],
                prompt_token=prompt_token,
                completion_token=completion_token,
            )
        except Exception as exception:
            logger.error("Chat completion failed: %s", exception, exc_info=True)
            # Stringify the exception so ``content`` holds text, not an Exception object.
            return ChatCompletion(state="error", content=str(exception))

    def stream_chat_completion(
        self, message: List[dict], **kwargs
    ) -> Iterator[ChatCompletion]:
        """
        Stream a chat completion chunk by chunk.

        :param message: The message (scratchpad) to use for completion.
            Usually a list of role/content dicts.
        :type message: List[dict]
        :param kwargs: Additional keyword arguments forwarded to the API call.
            Sampling parameters given here override the configured ones.
        :type kwargs: dict
        :return: Iterator of ChatCompletion objects, one per streamed chunk.
            Token counts are reported as 0. On failure, a single ChatCompletion
            with ``state="error"`` is yielded and iteration ends.
        :rtype: Iterator[ChatCompletion]
        """
        try:
            response = openai.ChatCompletion.create(
                model=self.model_name,
                messages=message,
                stream=True,
                **{**self._sampling_kwargs(), **kwargs},
            )
            # The role is read from the first chunk; that chunk's content is
            # not yielded.
            role = next(response).choices[0].delta["role"]
            # TODO: calculate prompt_token and completion_token in stream mode.
            for chunk in response:
                yield ChatCompletion(
                    state="success",
                    role=role,
                    content=chunk.choices[0].delta.get("content") or "",
                    prompt_token=0,
                    completion_token=0,
                )
        except Exception as exception:
            logger.error("Stream chat completion failed: %s", exception, exc_info=True)
            # This is a generator: ``return value`` would be hidden inside
            # StopIteration, so the error has to be yielded to reach the caller.
            yield ChatCompletion(state="error", content=str(exception))

    def function_chat_completion(
        self,
        message: List[dict],
        function_map: Dict[str, Callable],
        function_schema: List[Dict],
    ) -> ChatCompletionWithHistory:
        """
        Chat completion that may call one of the provided functions.

        If the model requests a function call, the function is looked up in
        ``function_map`` and called with the JSON-decoded arguments. Its result
        is appended to ``message`` as a ``"function"`` message, and a second
        request produces the final answer. ``message`` is mutated in place and
        returned as ``message_scratchpad``.

        :param message: The chat history; appended to in place.
        :type message: List[dict]
        :param function_map: Maps function names to callables.
        :type function_map: Dict[str, Callable]
        :param function_schema: The function schema sent to the model.
        :type function_schema: List[Dict]
        :return: ChatCompletionWithHistory. Token counts are summed over both
            requests when a function was called. On any exception the error is
            logged and a result with ``state="error"`` is returned.
        :rtype: ChatCompletionWithHistory
        :raises AssertionError: if ``function_schema`` and ``function_map``
            differ in length.
        """
        assert len(function_schema) == len(function_map)
        try:
            response = openai.ChatCompletion.create(
                model=self.model_name,
                messages=message,
                functions=function_schema,
                **self._sampling_kwargs(),
            )
            prompt_token, completion_token = self._token_usage(response)
            response_message = response.choices[0]["message"]
            function_call = response_message.get("function_call")

            if not function_call:
                message.append(dict(response_message))
                return ChatCompletionWithHistory(
                    state="success",
                    role=response_message["role"],
                    content=response_message["content"],
                    prompt_token=prompt_token,
                    completion_token=completion_token,
                    message_scratchpad=message,
                )

            function_name = function_call["name"]
            function_to_call = function_map[function_name]
            function_args = json.loads(function_call["arguments"])
            function_response = function_to_call(**function_args)

            # Postprocess the function response: plain strings carry no cost,
            # AgentOutput carries cost/token usage alongside its output.
            if isinstance(function_response, str):
                plugin_cost = 0
                plugin_token = 0
            elif isinstance(function_response, AgentOutput):
                plugin_cost = function_response.cost
                plugin_token = function_response.token_usage
                function_response = function_response.output
            else:
                raise TypeError(
                    "Invalid tool response type. Must be one of [AgentOutput, str]"
                )

            message.append(dict(response_message))
            message.append(
                {
                    "role": "function",
                    "name": function_name,
                    "content": function_response,
                }
            )
            # The follow-up request uses the same configured sampling
            # parameters as the first one.
            second_response = openai.ChatCompletion.create(
                model=self.get_model_name(),
                messages=message,
                **self._sampling_kwargs(),
            )
            second_message = second_response.choices[0].message
            message.append(dict(second_message))
            second_prompt, second_completion = self._token_usage(second_response)
            return ChatCompletionWithHistory(
                state="success",
                role=second_message["role"],
                content=second_message["content"],
                prompt_token=prompt_token + second_prompt,
                completion_token=completion_token + second_completion,
                message_scratchpad=message,
                plugin_cost=plugin_cost,
                plugin_token=plugin_token,
            )
        except Exception as exception:
            logger.error("Function chat completion failed: %s", exception, exc_info=True)
            return ChatCompletionWithHistory(state="error", content=str(exception))

    def function_chat_stream_completion(
        self,
        message: List[dict],
        function_map: Dict[str, Callable],
        function_schema: List[Dict],
    ) -> Iterator[Tuple[str, ChatCompletionWithHistory]]:
        """
        Stream a chat completion that may answer with a function call.

        Yields ``(stream_type, chunk)`` tuples, where ``stream_type`` is
        ``"function_call"`` or ``"content"`` (decided from the first streamed
        chunk). For a function call, the first yielded chunk holds the opening
        text ``{"name":"<name>", "arguments":`` and later chunks hold argument
        fragments. The function is NOT executed here; ``function_map`` is only
        used for the length check.

        :param message: The chat history; returned as ``message_scratchpad``.
        :type message: List[dict]
        :param function_map: Maps function names to callables.
        :type function_map: Dict[str, Callable]
        :param function_schema: The function schema sent to the model.
        :type function_schema: List[Dict]
        :return: Iterator of ``(stream_type, ChatCompletionWithHistory)``.
        :raises AssertionError: if ``function_schema`` and ``function_map``
            differ in length (raised when iteration starts).
        :raises Exception: any request/streaming error is logged and re-raised.
        """
        assert len(function_schema) == len(function_map)
        try:
            response = openai.ChatCompletion.create(
                model=self.get_model_name(),
                messages=message,
                functions=function_schema,
                stream=True,
                **self._sampling_kwargs(),
            )
            first_delta = next(response).choices[0].delta
            role = first_delta["role"]
            # Decide by the presence of ``function_call`` rather than indexing
            # ``content``, which some servers omit from the first chunk.
            stream_type = (
                "function_call" if first_delta.get("function_call") else "content"
            )
            if stream_type == "function_call":
                name = first_delta["function_call"]["name"]
                yield stream_type, ChatCompletionWithHistory(
                    state="success",
                    role=role,
                    content="{" + f'"name":"{name}", "arguments":',
                    message_scratchpad=message,
                )
            for chunk in response:
                content = chunk.choices[0].delta.get(stream_type) or ""
                if isinstance(content, dict):
                    # function_call deltas carry argument text fragments.
                    content = content.get("arguments") or ""
                yield stream_type, ChatCompletionWithHistory(
                    state="success",
                    role=role,
                    content=content,
                    message_scratchpad=message,
                )
        except Exception as exception:
            logger.error("Failed to get response %s", exception, exc_info=True)
            raise