File size: 3,621 Bytes
6c89611 16cc10f be09774 16cc10f e5f9801 be09774 16cc10f e5f9801 16cc10f e5f9801 6c89611 16cc10f e5f9801 6c89611 16cc10f 6c89611 16cc10f 6c89611 16cc10f e5f9801 16cc10f be09774 16cc10f 6c89611 16cc10f 6c89611 16cc10f 6c89611 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import copy
from enum import Enum
from os import PathLike
from openai import OpenAI
from commons.loggerfactory import LoggerFactory
from commons.utils import getdefault
from scripts.io import InputOutput
# Default texts for an interactive chat exchange (not referenced by the classes
# below; presumably consumed by the I/O layer — confirm with callers).
DEFAULT_WELCOME_MESSAGE: str = "How can I assist you ..."
DEFAULT_USER_PROMPT: str = "User: "
DEFAULT_AGENT_PROMPT: str = "Agent: "
DEFAULT_EXIT_MESSAGE: str = "Have a nice day!"
class Role(Enum):
    """Author of a chat message; each value is the ``role`` string of a message dict."""

    SYSTEM = "system"  # seeds the conversation (used for the system prompt)
    USER = "user"  # input supplied by the person chatting
    ASSISTANT = "assistant"  # replies returned by the model
class ChatHistory:
    """Size-bounded, ordered list of ``{"role": ..., "content": ...}`` message dicts.

    The first entry is pinned. When the history holds more than
    ``context_span`` entries, the oldest entries *after* the first one are
    evicted, so the first entry always survives trimming. OpenAIBot seeds it
    with the system prompt.
    """

    def __init__(self, context_span: int, initial_context: list[dict[str, str]]):
        """Create a history seeded with ``initial_context``.

        :param context_span: maximum number of entries kept, including the
            pinned first entry. Values below 1 behave like 1 (only the pinned
            entry is kept) instead of raising ``IndexError``.
        :param initial_context: messages appended in order; they are trimmed
            to ``context_span`` exactly like later additions.
        """
        self.__context_span = context_span
        self.__context: list[dict[str, str]] = []
        for context in initial_context:
            self.__add_context(context)

    def __ensure_max_size(self) -> "ChatHistory":
        """Evict the oldest non-pinned entries until the size limit holds; return self."""
        # Always remove index 1 so index 0 (the pinned entry) is never evicted.
        # Clamping the limit to 1 keeps a context_span < 1 from calling pop(1)
        # on a one-element list, which would raise IndexError.
        limit = max(self.__context_span, 1)
        while len(self.__context) > limit:
            self.__context.pop(1)
        return self

    def __add_context(self, context: dict[str, str]) -> "ChatHistory":
        """Append a raw message dict, trim to the size limit, and return self."""
        self.__context.append(context)
        return self.__ensure_max_size()

    def add_message(self, role: "Role", content: str) -> "ChatHistory":
        """Append a message authored by ``role`` and return self for chaining."""
        return self.__add_context({"role": role.value, "content": content})

    def get_whole_context(self) -> list[dict[str, str]]:
        """Return a shallow copy of all stored messages, pinned entry first.

        The list is new, so appending to it does not alter the history, but
        the message dicts inside are shared with the history.
        """
        return self.__ensure_max_size().__context.copy()

    def get_chat_history(self) -> list[dict[str, str]]:
        """Return a shallow copy of the stored messages.

        NOTE(review): this is currently identical to get_whole_context(). If it
        was meant to exclude the pinned first entry (e.g. the system prompt),
        it should slice ``[1:]`` — confirm with callers before changing.
        """
        return self.get_whole_context()

    def get_context_size(self) -> int:
        """Return the number of stored messages, including the pinned entry."""
        return len(self.__ensure_max_size().__context)

    def last_in_history(self) -> dict[str, str]:
        """Return the most recently stored message (raises IndexError if empty)."""
        return self.__context[-1]

    def reset(self) -> "ChatHistory":
        """Remove every message, including the pinned first entry, and return self.

        After a reset, the next message added becomes the new pinned entry.
        """
        self.__context.clear()
        return self
class OpenAIBot:
    """Chat bot that sends a size-bounded conversation to an OpenAI chat-completions model."""

    def __init__(self, bot: OpenAI, model: str, prompts: list[PathLike | str], context_span: int, **args):
        """Build the bot and its default conversation history.

        :param bot: OpenAI client used for ``chat.completions.create`` calls.
        :param model: model name passed with every completion request.
        :param prompts: paths of text files whose contents, concatenated in
            order, become the system prompt (the pinned first history entry).
        :param context_span: maximum number of messages kept in a history,
            system prompt included.
        :param args: optional settings; ``exit_codes`` (read via ``getdefault``,
            default empty) lists inputs that end the conversation, matched
            case-insensitively.
        """
        self.logger = LoggerFactory.getLogger(self.__class__.__name__)
        self.__bot = bot
        self.__model = model
        system_prompt = self._load_prompt(prompts)
        # Template history; never mutated by respond(), only deep-copied.
        self.__history_instance = ChatHistory(
            context_span=context_span,
            initial_context=[{"role": Role.SYSTEM.value, "content": system_prompt}],
        )
        # Default history used by respond() when the caller supplies none.
        self.__history = self.get_history_copy()
        # Lower-cased once so __is_exit's lower-cased input can match codes
        # that were configured with upper-case letters.
        self.__exit_codes: set[str] = {code.lower() for code in getdefault(args, "exit_codes", list())}

    @staticmethod
    def _load_prompt(prompts: list[PathLike | str]) -> str:
        """Return the contents of all prompt files concatenated in order, without separators."""
        parts = []
        for prompt in prompts:
            with open(prompt, "r", encoding="utf-8") as pf:
                parts.append(pf.read())
        return "".join(parts)

    def __is_exit(self, message: str) -> bool:
        """Return True if ``message`` (case-insensitively) is a configured exit code."""
        return message.lower() in self.__exit_codes

    def respond(self, user_input: str, history: ChatHistory = None, append_user_input: bool = True) -> str | bool | None:
        """Send ``user_input`` with the conversation context to the model and return the reply.

        :param user_input: the user's message.
        :param history: conversation to use and update; defaults to this bot's
            own history.
        :param append_user_input: if True, the user message is stored in
            ``history`` before the request. If False, it is sent for this
            request only and not stored. The assistant's reply is stored in
            ``history`` in both cases.
        :return: ``False`` if ``user_input`` is an exit code, ``None`` if it is
            empty, otherwise the assistant's reply text.
        """
        if history is None:
            history = self.__history
        if self.__is_exit(user_input):
            return False
        if not user_input:
            return None
        if append_user_input:
            messages = history.add_message(Role.USER, user_input).get_whole_context()
        else:
            # get_whole_context() returns a fresh list, so this append is not
            # persisted to the history.
            messages = history.get_whole_context()
            messages.append({"role": Role.USER.value, "content": user_input})
        chat = self.__bot.chat.completions.create(model=self.__model, messages=messages)
        # The SDK types `usage` as optional; skip the token log rather than crash.
        if chat.usage is not None:
            self.logger.info("Tokens count, prompts: %s, completion: %s, total: %s",
                             chat.usage.prompt_tokens, chat.usage.completion_tokens, chat.usage.total_tokens)
        reply = chat.choices[0].message.content
        history.add_message(Role.ASSISTANT, reply)
        return reply

    def get_history_copy(self) -> ChatHistory:
        """Return an independent deep copy of the template history (system prompt only)."""
        return copy.deepcopy(self.__history_instance)
|