File size: 3,317 Bytes
6c89611
 
 
 
be09774
6c89611
 
 
7b9bd20
 
 
 
 
6c89611
 
be09774
6c89611
 
 
 
 
be09774
e5f9801
6c89611
e5f9801
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca12f66
e5f9801
 
 
 
 
 
6c89611
 
 
 
 
2dc92ec
be09774
6c89611
 
 
 
e5f9801
7b9bd20
e5f9801
 
 
7b9bd20
e5f9801
 
6c89611
 
 
 
e5f9801
6c89611
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from pathlib import Path

import gradio as gr

from commons.loggerfactory import LoggerFactory
from scripts.components import OpenAIBot, DEFAULT_WELCOME_MESSAGE, DEFAULT_EXIT_MESSAGE, ChatHistory


GREETING_PROMPT = ("You con imagine a story hook and imagine the state of Rick.\n"
                   "You see Morty, Greet him in few words!")
EXIT_PROMPT = "Morty is leaving. Give an apt response. Keep it short"


class GradioBotUI:
    def __init__(self, bot: OpenAIBot, user_avatar: str | Path, agent_avatar: str | Path):
        self.logger = LoggerFactory.getLogger(self.__class__.__name__)
        self.bot = bot
        self.user_avatar = user_avatar
        self.agent_avatar = agent_avatar

    def launch(self):
        self.logger.info("Launching gradio ui ...")
        with gr.Blocks(title="Rick and Morty!", fill_height=True) as demo:
            with gr.Row():
                gr.Image(value="media/rnm_title.png", height="100px", container=False, show_download_button=False,
                         show_fullscreen_button=False)

            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        gr.HTML(value="<center><h1>Adventure generator</h1></center>")
                    with gr.Row():
                        gr.Image(value="media/rnm_poster.png", container=False, show_download_button=False,
                                 show_fullscreen_button=False, height="80%")

                with gr.Column():
                    history = self.bot.get_history_copy()
                    greetings = self.greeting(history)
                    with gr.Row():
                        chatbot = gr.Chatbot(bubble_full_width=False, show_copy_button=True,
                                             avatar_images=(self.user_avatar, self.agent_avatar), scale=1,
                                             value=[[None, greetings]], show_copy_all_button=True)

                    with gr.Row():
                        prompt = gr.Textbox(label="Prompt")
                    state = gr.State(value=history)

            user_step = prompt.submit(self.user_message, inputs=[prompt, chatbot],
                                      outputs=[prompt, chatbot], queue=False)
            user_step.then(self.agent_message, inputs=[state, chatbot], outputs=[prompt, state, chatbot])

        demo.launch(server_port=8080, server_name="0.0.0.0")
        self.logger.info("Done launching gradio ui!")

    def user_message(self, user_text, history):
        return gr.Textbox(label="Prompt"), history + [[user_text, None]]

    def greeting(self, chat_history: ChatHistory):
        response = self.bot.respond(GREETING_PROMPT, chat_history, append_user_input=True)
        return response

    def __bye(self, chat_history: ChatHistory):
        response = self.bot.respond(EXIT_PROMPT, chat_history, append_user_input=False)
        return response

    def agent_message(self, chat_history: ChatHistory, history):
        user_text = history[-1][0]
        response = self.bot.respond(user_text, chat_history)
        if not response:
            history[-1][1] = self.__bye(chat_history)
            return gr.Textbox(label="Prompt", interactive=False), chat_history, history
        history[-1][1] = response
        return "", chat_history, history