# sto-rai / src/frontend/continue_story_screen.py
# yiiilonggg — commit 5ac7eb0: "Fix loading json bugs"
import gradio as gr
import os
from typing import Dict, List, Tuple, Union
import src.agents.coordinator as C
from src.frontend.chatbot import (
chatbot,
init_graph,
on_chatbot_response,
option_buttons,
restart_button
)
from src.frontend import sidebar
from src.utils.utils import load_information, transform_story_name
# Instruction text shown at the top of the "continue story" screen. It starts
# hidden; its visibility is toggled through get_widgets_updates().
information_text = gr.Text(
    label='',
    value='Select which story to continue below!',
    visible=False,
    interactive=False,
    key='continue_story_user_information',
    preserved_by_key='key',
)
# Saved story sessions live in a ``sessions`` folder two directories above
# this module (presumably the repository root — confirm if the layout moves).
session_dir = os.path.join(
    os.path.dirname(__file__), '..', '..', 'sessions'
)
# One (row, button, context text) triple per saved story, in sorted
# directory-name order. All components start hidden and are rendered by
# render().
story_widgets: List[Tuple[gr.Row, gr.Button, gr.Text]] = []
# Maps a story's display name (the button label) to its session directory.
# NOTE(review): two sessions with the same story_name collide here and the
# later directory wins; on_button_click can only look stories up by name.
story_dir_mapper: Dict[str, str] = dict()
# A missing ``sessions`` folder (e.g. a fresh checkout) means "no saved
# stories" rather than a crash at import time.
for i, dirname in enumerate(
    sorted(os.listdir(session_dir)) if os.path.isdir(session_dir) else []
):
    # Ignore hidden entries such as .gitkeep or .DS_Store.
    if dirname.startswith('.'):
        continue
    dirpath = os.path.join(session_dir, dirname)
    # Only non-empty directories can hold a saved story. Without the isdir
    # check, a stray regular file would make os.listdir raise
    # NotADirectoryError.
    if not os.path.isdir(dirpath) or not os.listdir(dirpath):
        continue
    story_information_filepath = os.path.join(dirpath, 'story.json')
    # A session without story.json cannot be listed or resumed, so skip it
    # instead of failing on the 'story_name' lookup below.
    if not os.path.isfile(story_information_filepath):
        continue
    story_information_dict = load_information(story_information_filepath)
    story_name = story_information_dict['story_name']
    story_context = story_information_dict['story_context']
    # Keys use the enumerate index so every component keeps a stable,
    # unique key for preserved_by_key.
    row = gr.Row(
        visible=False,
        key=f'continue_row_{i}',
        preserved_by_key='key'
    )
    button = gr.Button(
        value=story_name,
        visible=False,
        key=f'continue_button_{i}',
        preserved_by_key='key',
        scale=1
    )
    text = gr.Text(
        value=story_context,
        label='',
        interactive=False,
        visible=False,
        key=f'continue_story_{i}',
        preserved_by_key='key',
        scale=3
    )
    story_widgets.append((row, button, text))
    story_dir_mapper[story_name] = dirpath
def get_widgets() -> List[Union[gr.Text, gr.Button, gr.Row]]:
    """Return every component owned by the "continue story" screen.

    The list starts with the information text, followed by each saved
    story's row, button and context text, in that order.
    """
    return [information_text] + [
        component
        for story_triple in story_widgets
        for component in story_triple
    ]
def get_widgets_updates(
    is_visible: bool
) -> Dict[Union[gr.Slider, gr.Text, gr.Button], gr.update]:
    """Build a visibility update for every component on this screen.

    Rows are updated with a fresh ``gr.Row(visible=...)``. Every other
    component gets ``gr.update(visible=...)``.

    Args:
        is_visible: Whether the components should be shown.

    Returns:
        Mapping from each component returned by get_widgets() to its update.
    """
    return {
        component: (
            gr.Row(visible=is_visible)
            if isinstance(component, gr.Row)
            else gr.update(visible=is_visible)
        )
        for component in get_widgets()
    }
def on_button_click(
    story_name: str
) -> Dict[Union[gr.Text, gr.Slider, gr.Button, gr.Chatbot], gr.update]:
    """Resume the saved story whose button was clicked.

    Re-reads the story's ``story.json`` and passes its saved settings to
    init_graph(). It then sets the coordinator's story name and calls
    C.load_coordinator().

    Args:
        story_name: Label of the clicked button, i.e. the saved story's name;
            must be a key of story_dir_mapper.

    Returns:
        Merged component updates: on_chatbot_response([]), then this
        screen's widgets hidden, then sidebar.view_screen() (later entries
        override earlier ones for the same component).
    """
    session_path = story_dir_mapper[story_name]
    coordinator_story_name = transform_story_name(story_name)
    saved = load_information(os.path.join(session_path, 'story.json'))
    # Positional order must match init_graph's parameter order.
    graph_args = [
        saved[field]
        for field in (
            'story_context',
            'categories_context',
            'num_questions',
            'num_options',
            'num_categories',
        )
    ]
    init_graph(*graph_args)
    C.story_name = coordinator_story_name
    C.load_coordinator()
    updates = on_chatbot_response([])
    updates = updates | get_widgets_updates(False)
    return updates | sidebar.view_screen()
def render():
    """Place this screen's components into the current Gradio layout.

    Each saved story is rendered as a row containing its button and context
    text. Clicking a story's button calls on_button_click with the button's
    label and updates this screen, the chatbot widgets and the sidebar.
    """
    information_text.render()
    for story_row, story_button, story_text in story_widgets:
        story_row.render()
        with story_row:
            story_button.render()
            story_text.render()
        click_outputs = [
            *get_widgets(),
            chatbot,
            restart_button,
            *option_buttons,
            *sidebar.get_widgets(),
        ]
        story_button.click(
            fn=on_button_click,
            inputs=[story_button],
            outputs=click_outputs,
        )