# sto-rai / src/frontend/chatbot.py
# Commit 5ac7eb0 by yiiilonggg: "Fix loading json bugs"
import gradio as gr
from langgraph.graph.graph import CompiledGraph
from typing import Dict, List, Union
import src.agents.coordinator as C
from src.dataclasses.process_state import ProcessState
# Compiled story workflow; None until init_graph() has been called.
graph: Union[CompiledGraph, None] = None
# Number of option buttons created up front; init_graph() trims the list down
# to the number of options a story actually uses.
MAX_OPTIONS = 16
# Conversation state handed to graph.invoke() on every question turn and
# replaced by its return value.
state: ProcessState = {
    'question_num': 0,
    'state_name': 'start',
    'user_choice': '',
    'previous_state_names': [],
    'previous_questions': [],
    'previous_user_choices': [],
    'builder_instruction': '',
    'foll_question_uuid': ''
}
# Snapshots of `state`, appended by on_chatbot_response() before each turn.
state_history: List[ProcessState] = []
# Number of questions asked before the evaluation step; None until
# init_graph() has been called.
num_questions_: Union[int, None] = None
# Chat transcript; starts hidden and is made visible by control_screen() or
# on_chatbot_response().
chatbot = gr.Chatbot(
    visible=False,
    type='messages',
    preserved_by_key='key',
    key='chatbot',
)
# One answer button per possible option, created unrendered and hidden.
# render() places them inside button_row; init_graph() trims this list to the
# number of options actually in use.
option_buttons = [
    gr.Button(
        key=f'option_button_{position}',
        value='',
        render=False,
        visible=False,
        preserved_by_key='key',
    )
    for position in range(1, MAX_OPTIONS + 1)
]
# Shown after the evaluation step so the user can start a new storyline.
restart_button = gr.Button(
    'Explore other storylines!',
    key='restart_button',
    interactive=True,
    visible=False,
    preserved_by_key='key',
)
# Layout row that holds the option buttons.
button_row = gr.Row(
    preserved_by_key='key',
    key='button_row',
)
def init_graph(
    story_context: str,
    categories_context: str,
    num_questions: int,
    num_options: int,
    num_categories: int
) -> None:
    """Build the story graph and trim the option buttons to ``num_options``.

    Args:
        story_context: Forwarded to ``C.init_graph``.
        categories_context: Forwarded to ``C.init_graph``.
        num_questions: Number of questions asked before the evaluation step;
            stored in the module-level ``num_questions_``.
        num_options: Number of answer buttons to keep. Buttons beyond this
            count are unrendered and dropped from ``option_buttons``.
        num_categories: Forwarded to ``C.init_graph``.
    """
    global graph, option_buttons, num_questions_
    num_questions_ = num_questions
    graph = C.init_graph(
        story_context,
        categories_context,
        num_questions,
        num_options,
        num_categories
    )
    # Walk the surplus buttons that actually exist instead of indexing up to
    # MAX_OPTIONS: after an earlier call has already truncated option_buttons,
    # indexing by MAX_OPTIONS would raise IndexError.
    for surplus_button in option_buttons[num_options:]:
        surplus_button.unrender()
    # NOTE(review): truncation is one-way; a later call with a larger
    # num_options cannot bring dropped buttons back.
    option_buttons = option_buttons[:num_options]
def on_user_response(
    user_choice: str,
    history: List[Dict[str, str]]
) -> Dict[Union[gr.Button, gr.Chatbot], Union[List[Dict[str, str]], gr.update]]:
    """Record the clicked option and echo it into the chat as a user message.

    Args:
        user_choice: Label of the button the user clicked.
        history: Current chat transcript in ``messages`` format.

    Returns:
        An update mapping: the chatbot receives the extended transcript and
        every option button is hidden until the next question arrives.
    """
    # Button labels show spaces where the underlying option keys have
    # underscores (see on_chatbot_response); store the underscore form.
    state['user_choice'] = user_choice.replace(' ', '_')
    updates = {chatbot: history + [{'role': 'user', 'content': user_choice}]}
    for option_button in option_buttons:
        updates[option_button] = gr.update(visible=False)
    return updates
def control_screen_widgets() -> List[Union[gr.Chatbot, gr.Row, gr.Button]]:
    """Return the widgets that control_screen() produces updates for."""
    widgets = [chatbot, button_row]
    widgets.append(restart_button)
    return widgets
def control_screen(
    is_visible: bool
) -> Dict[Union[gr.Chatbot, gr.Row, gr.Button], Union[gr.update, gr.Row]]:
    """Show or hide the chat screen.

    The transcript and the option-button row follow ``is_visible``; the restart
    button is hidden either way.
    """
    updates = {}
    updates[chatbot] = gr.update(visible=is_visible)
    updates[button_row] = gr.Row(visible=is_visible)
    updates[restart_button] = gr.update(visible=False)
    return updates
def _clean_text(text: str) -> str:
    """Turn an underscore-separated, possibly quoted string into display text.

    Underscores become spaces, then surrounding double quotes are stripped,
    then surrounding single quotes.
    """
    return text.replace('_', ' ').strip('\"').strip('\'')


def _question_updates(
    history: List[Dict[str, str]]
) -> tuple:
    """Build the transcript and button updates for the current question.

    Reads the module-level ``state`` (already advanced by the graph) to look up
    the scenario text and follow-up question in ``C``.

    Returns:
        ``(updated_history, button_updates)``: the transcript with the scenario
        and question appended, and updates that label/show one option button
        per answer option and hide the restart button.
    """
    question_num = state['question_num']
    scenario = C.states[question_num][state['state_name']]
    question = C.questions[state['foll_question_uuid']]
    text_to_user = scenario + '\n' + _clean_text(question.question)
    updated_history = history + [{'role': 'assistant', 'content': text_to_user}]
    options = [_clean_text(option) for option in question.options.keys()]
    # zip() stops at the shorter list: options beyond the available buttons are
    # not shown, and buttons beyond the options are left untouched.
    button_updates = {
        button: gr.update(value=option, visible=True)
        for option, button in zip(options, option_buttons)
    }
    button_updates[restart_button] = gr.update(visible=False)
    return updated_history, button_updates


def _evaluation_updates(
    history: List[Dict[str, str]]
) -> tuple:
    """Run the final evaluation and build the transcript and button updates.

    Marks the selected category as seen in ``C.categories_seen`` when it matches
    an entry of ``C.categories``.

    Returns:
        ``(updated_history, button_updates)``: the transcript with the result
        messages appended, and updates that hide every option button and show
        the restart button.
    """
    selected_category, reason = C.evaluation(state)
    # Empty defaults keep the message free of a literal "None" when the
    # evaluator names a category that is not in C.categories.
    description, traits = '', ''
    for category in C.categories:
        if category.name != selected_category:
            continue
        description = category.description
        traits = category.traits
        C.categories_seen[selected_category] = True
    bot_messages = [
        {'role': 'assistant', 'content': f'We think that you are closest to **{selected_category}**!'},
        {'role': 'assistant', 'content': f'## {selected_category}\n\n{description}\n\n{traits}'},
        {'role': 'assistant', 'content': reason}
    ]
    updated_history = history + bot_messages
    button_updates = {
        button: gr.update(visible=False)
        for button in option_buttons
    }
    button_updates[restart_button] = gr.update(visible=True)
    return updated_history, button_updates


def on_chatbot_response(
    history: List[Dict[str, str]]
) -> Dict[Union[gr.Chatbot, gr.Button], Union[List[Dict[str, str]], gr.update]]:
    """Advance the story by one turn and return the resulting UI updates.

    While fewer than ``num_questions_`` questions have been asked, the graph is
    invoked to produce the next scenario and question. Otherwise the user is
    evaluated and the result is shown. The coordinator is saved in both cases.

    Args:
        history: Current chat transcript in ``messages`` format.

    Returns:
        An update mapping for the chatbot (new transcript, made visible), the
        option buttons and the restart button.

    Raises:
        RuntimeError: If ``init_graph()`` has not been called yet.
    """
    global state
    if num_questions_ is None:
        raise RuntimeError('init_graph() must be called before on_chatbot_response().')
    # Shallow snapshot taken before the graph replaces `state`.
    state_history.append(dict(state))
    if state['question_num'] < num_questions_:
        state = graph.invoke(state)
        updated_history, button_updates = _question_updates(history)
    else:
        updated_history, button_updates = _evaluation_updates(history)
    C.save_coordinator()
    return {chatbot: gr.update(value=updated_history, visible=True)} | button_updates
def on_restart_button_click() -> Dict[
    Union[gr.Chatbot, gr.Button], Union[List[Dict[str, str]], gr.update]
]:
    """Reset the story to its starting state and show the first question.

    The restart button is hidden, then on_chatbot_response() is run against an
    empty transcript. Its updates are merged on top, so its own restart_button
    entry takes precedence.
    """
    global state
    fresh_state = {
        'question_num': 0,
        'state_name': 'start',
        'user_choice': '',
        'previous_state_names': [],
        'previous_questions': [],
        'previous_user_choices': [],
        'builder_instruction': '',
        'foll_question_uuid': ''
    }
    state = fresh_state
    updates = {restart_button: gr.update(visible=False)}
    updates.update(on_chatbot_response([]))
    return updates
def _wire_option_button(button: gr.Button) -> None:
    """Render one option button into the current layout and attach its events.

    A click first records the answer (on_user_response), then lets the chatbot
    respond (on_chatbot_response).
    """
    button.unrender()
    button.render()
    click_event = button.click(
        fn=on_user_response,
        inputs=[button, chatbot],
        outputs=option_buttons + [chatbot]
    )
    click_event.then(
        fn=on_chatbot_response,
        inputs=[chatbot],
        outputs=option_buttons + [chatbot, restart_button]
    )


def render():
    """Render the chat widgets and register their event handlers.

    Presumably called inside an active ``gr.Blocks`` context. The chatbot is
    rendered first, then the option buttons inside ``button_row``, then the
    restart button.
    """
    chatbot.render()
    button_row.render()
    with button_row:
        for option_button in option_buttons:
            _wire_option_button(option_button)
    restart_button.render()
    restart_button.click(
        fn=on_restart_button_click,
        inputs=None,
        outputs=option_buttons + [chatbot, restart_button]
    )