File size: 3,621 Bytes
6c89611
16cc10f
 
 
 
 
be09774
16cc10f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5f9801
be09774
16cc10f
 
e5f9801
 
16cc10f
e5f9801
 
6c89611
 
 
16cc10f
 
 
 
 
e5f9801
6c89611
 
16cc10f
 
 
 
6c89611
16cc10f
6c89611
16cc10f
e5f9801
 
 
 
16cc10f
be09774
 
16cc10f
 
 
6c89611
 
16cc10f
6c89611
16cc10f
6c89611
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import copy
from enum import Enum
from os import PathLike

from openai import OpenAI

from commons.loggerfactory import LoggerFactory
from commons.utils import getdefault
from scripts.io import InputOutput

# Default strings for an interactive chat front-end (welcome banner, user and
# agent line prefixes, farewell). NOTE(review): none of them is referenced in
# this part of the module — presumably consumed by the caller driving the chat.
DEFAULT_WELCOME_MESSAGE = "How can I assist you ..."
DEFAULT_USER_PROMPT = "User: "
DEFAULT_AGENT_PROMPT = "Agent: "
DEFAULT_EXIT_MESSAGE = "Have a nice day!"


class Role(Enum):
    """Author of a chat message; ``.value`` is the string stored under the message's ``"role"`` key."""
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"


class ChatHistory:
    """Bounded, ordered list of chat messages used as model context.

    The first stored message (typically the system prompt) is pinned: when the
    history grows beyond ``context_span`` entries, the oldest messages *after*
    it are evicted first.
    """

    def __init__(self, context_span: int, initial_context: list[dict[str, str]]):
        """Create a history seeded with ``initial_context``.

        :param context_span: maximum number of messages retained, including the
            pinned first message; must be at least 1.
        :param initial_context: messages (``{"role": ..., "content": ...}``
            dicts) to start from, trimmed to ``context_span`` as they are added.
        :raises ValueError: if ``context_span`` is smaller than 1.
        """
        # Eviction never removes index 0, so a span below 1 can never be met:
        # list.pop(1) would fail with IndexError once a single message remains.
        if context_span < 1:
            raise ValueError(f"context_span must be at least 1, got {context_span!r}")
        self.__context_span = context_span
        self.__context: list[dict[str, str]] = []
        for context in initial_context:
            self.__add_context(context)

    def __ensure_max_size(self) -> "ChatHistory":
        """Evict messages after the first one until at most ``context_span`` remain."""
        while len(self.__context) > self.__context_span:
            # Index 1, not 0: the first message is pinned.
            self.__context.pop(1)
        return self

    def __add_context(self, context: dict[str, str]) -> "ChatHistory":
        """Append ``context`` (stored as-is, not copied) and trim to size."""
        self.__context.append(context)
        return self.__ensure_max_size()

    def add_message(self, role: "Role", content: str) -> "ChatHistory":
        """Append a ``{"role": role.value, "content": content}`` message.

        :param role: message author; only its ``.value`` string is stored.
        :param content: message text.
        :returns: ``self``, so calls can be chained.
        """
        return self.__add_context({"role": role.value, "content": content})

    def get_whole_context(self) -> list[dict[str, str]]:
        """Return all stored messages, oldest first.

        The returned list is a shallow copy: appending to it does not change the
        history, but the message dicts themselves are shared.
        """
        return self.__ensure_max_size().__context.copy()

    def get_chat_history(self) -> list[dict[str, str]]:
        """Return the same messages as :meth:`get_whole_context`."""
        # NOTE(review): identical to get_whole_context(). The name suggests it
        # may have been meant to exclude the pinned first message (i.e. slice
        # ``[1:]``) — confirm with callers before changing.
        return self.get_whole_context()

    def get_context_size(self) -> int:
        """Return the number of stored messages (never more than ``context_span``)."""
        return len(self.__ensure_max_size().__context)

    def last_in_history(self) -> dict[str, str]:
        """Return the most recent message; raises IndexError if the history is empty."""
        return self.__context[-1]

    def reset(self) -> "ChatHistory":
        """Remove every message, including the pinned first one; returns ``self``."""
        self.__context.clear()
        return self


class OpenAIBot:
    """Chat bot backed by an OpenAI chat-completions client.

    Keeps a pristine :class:`ChatHistory` seeded with the system prompt, plus a
    working copy of it that :meth:`respond` uses when no history is passed.
    """

    # System prompt used when ``prompts`` yields no text (e.g. an empty list),
    # so the history always starts with a system message.
    DEFAULT_SYSTEM_PROMPT = "You are Rick Sanchez from Rick and Morty."

    def __init__(self, bot: OpenAI, model: str, prompts: list[PathLike | str], context_span: int, **args):
        """Build the bot and its initial history.

        :param bot: client whose ``chat.completions.create`` is called.
        :param model: model name forwarded to ``chat.completions.create``.
        :param prompts: paths of text files; their contents are concatenated,
            in order, to form the system prompt.
        :param context_span: maximum number of messages kept in a history.
        :param args: extra options. ``exit_codes`` (presumably returned by
            ``getdefault`` when present, else an empty list) lists the user
            inputs that end the conversation.
        """
        self.logger = LoggerFactory.getLogger(self.__class__.__name__)
        self.__bot = bot
        self.__model = model
        final_prompt = self._read_prompt_files(prompts) or self.DEFAULT_SYSTEM_PROMPT
        self.__history_instance = ChatHistory(
            context_span=context_span,
            initial_context=[{"role": Role.SYSTEM.value, "content": final_prompt}],
        )
        # Working history; the pristine instance above is only ever deep-copied.
        self.__history = self.get_history_copy()
        # Lower-cased once so the case-insensitive check in __is_exit also
        # matches codes that were configured with capital letters.
        self.__exit_codes: set[str] = {
            code.lower() for code in getdefault(args, "exit_codes", list())
        }

    @staticmethod
    def _read_prompt_files(prompts: list[PathLike | str]) -> str:
        """Return the text of every file in ``prompts``, concatenated in order."""
        parts = []
        for prompt in prompts:
            with open(prompt, "r", encoding="utf-8") as pf:
                parts.append(pf.read())
        return "".join(parts)

    def __is_exit(self, message: str | None) -> bool:
        """Return True when ``message`` matches a configured exit code, ignoring case."""
        return message is not None and message.lower() in self.__exit_codes

    def respond(self, user_input: str, history: ChatHistory | None = None,
                append_user_input: bool = True) -> str | bool | None:
        """Send ``user_input`` to the model and return its reply.

        :param user_input: the user's message; empty or ``None`` input sends
            nothing.
        :param history: conversation to read from and record into; defaults to
            this bot's own working history.
        :param append_user_input: when False, the user message is sent to the
            model but not stored in ``history`` (the reply is still stored).
        :returns: ``False`` if ``user_input`` is an exit code, ``None`` if there
            was nothing to send, otherwise the model's reply.
        """
        if history is None:
            history = self.__history

        if self.__is_exit(user_input):
            return False

        if not user_input:
            return None

        if append_user_input:
            messages = history.add_message(Role.USER, user_input).get_whole_context()
        else:
            # get_whole_context() returns a new list, so this user turn reaches
            # the model without being recorded in ``history``.
            messages = history.get_whole_context()
            messages.append({"role": Role.USER.value, "content": user_input})

        chat = self.__bot.chat.completions.create(model=self.__model, messages=messages)
        self.logger.info("Tokens count, prompts: %s, completion: %s, total: %s",
                         chat.usage.prompt_tokens, chat.usage.completion_tokens, chat.usage.total_tokens)

        # NOTE(review): the SDK types message.content as optional; a None reply
        # is stored and returned unchanged — confirm callers handle that.
        reply = chat.choices[0].message.content
        history.add_message(Role.ASSISTANT, reply)
        return reply

    def get_history_copy(self) -> ChatHistory:
        """Return an independent deep copy of the history as seeded at construction."""
        return copy.deepcopy(self.__history_instance)