Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import json | |
| import os | |
| from typing import Dict, List, Union | |
| import src.agents.coordinator as C | |
| from src.frontend.chatbot import ( | |
| MAX_OPTIONS, | |
| chatbot, | |
| init_graph, | |
| on_chatbot_response, | |
| option_buttons, | |
| restart_button | |
| ) | |
| from src.frontend import sidebar | |
| from src.utils.utils import ( | |
| get_session_dir, | |
| save_information, | |
| transform_story_name | |
| ) | |
# Free-text inputs of the story set-up form. They are constructed hidden
# (visible=False) at import time; render() calls .render() on each of them.
# NOTE(review): preserved_by_key='key' is kept exactly as written. Gradio's
# documented default for that parameter appears to be "value" -- confirm
# that preserving only `key` is what is intended here.
story_name_textbox = gr.Text(
    key='story_name_textbox',
    label='Story Name',
    placeholder='Give your adventure a unique name!',
    interactive=True,
    visible=False,
    preserved_by_key='key',
)

story_context_textbox = gr.Text(
    key='story_context_textbox',
    label='Story Context',
    placeholder='What kind of story would you like to adventure on?',
    interactive=True,
    visible=False,
    preserved_by_key='key',
)

categories_context_textbox = gr.Text(
    key='categories_context_textbox',
    label='Categories Context',
    placeholder='What is the theme of the categories you want to have?',
    interactive=True,
    visible=False,
    preserved_by_key='key',
)
# Numeric settings of the story set-up form, all integer-stepped and hidden
# until an update makes them visible.

# Number of questions: 1..10, starting at 5.
num_questions_slider = gr.Slider(
    key='num_questions_slider',
    label='How many questions would you want your journey to have?',
    minimum=1,
    maximum=10,
    step=1,
    value=5,
    interactive=True,
    visible=False,
    preserved_by_key='key',
)

# Number of options per turn: 2..4, starting at 2.
# NOTE(review): the upper bound here is the literal 4 while the categories
# slider below is bounded by MAX_OPTIONS -- verify the two bounds are not
# swapped.
num_options_slider = gr.Slider(
    key='num_options_slider',
    label='How many options would you like at each turn?',
    minimum=2,
    maximum=4,
    step=1,
    value=2,
    interactive=True,
    visible=False,
    preserved_by_key='key',
)

# Number of categories / endings: 2..MAX_OPTIONS, starting at 2.
num_categories_slider = gr.Slider(
    key='num_categories_slider',
    label='How many categories / endings would you like to have?',
    minimum=2,
    maximum=MAX_OPTIONS,
    step=1,
    value=2,
    interactive=True,
    visible=False,
    preserved_by_key='key',
)
# Submit button of the set-up form. It starts hidden and non-interactive;
# on_text_change() returns the update that toggles its `interactive` flag.
story_information_submit_button = gr.Button(
    key='story_information_submit_button',
    interactive=False,
    visible=False,
    preserved_by_key='key',
)
def get_widgets() -> List[Union[gr.Text, gr.Slider, gr.Button]]:
    """Return all widgets of the story set-up form as a new list.

    The order is fixed: the three text boxes, the three sliders, then the
    submit button. render() renders the widgets in this order.
    """
    text_inputs = [
        story_name_textbox,
        story_context_textbox,
        categories_context_textbox,
    ]
    sliders = [
        num_questions_slider,
        num_options_slider,
        num_categories_slider,
    ]
    return text_inputs + sliders + [story_information_submit_button]
def get_widgets_updates(
    is_visible: bool
) -> Dict[Union[gr.Slider, gr.Text, gr.Button], gr.update]:
    """Build a visibility update for every widget of the set-up form.

    Args:
        is_visible: Value placed in the ``visible`` field of each update.

    Returns:
        A dict mapping each widget from ``get_widgets()`` to its own
        ``gr.update(visible=is_visible)`` object.
    """
    visibility_updates = {}
    for form_widget in get_widgets():
        # One fresh update per widget, so no update object is shared.
        visibility_updates[form_widget] = gr.update(visible=is_visible)
    return visibility_updates
def check_story_name(story_name: str) -> bool:
    """Return True when *story_name* is still free to use in this session.

    The name is first converted with ``transform_story_name``. It counts as
    taken when a filesystem entry of that converted name already exists
    inside the directory returned by ``get_session_dir()``.

    Args:
        story_name: The story name as typed by the user.

    Returns:
        ``True`` if no entry exists at the target path, ``False`` otherwise.
    """
    story_name_dir = transform_story_name(story_name)
    candidate_path = os.path.join(get_session_dir(), story_name_dir)
    # Ask the filesystem about the exact path that save_story_information()
    # will later pass to os.mkdir(), instead of string-matching against
    # os.listdir(). This agrees with mkdir on case-insensitive filesystems,
    # avoids listing the whole directory on every keystroke, and does not
    # raise when the session directory has not been created yet.
    # lexists() also reports dangling symlinks, which would block mkdir too.
    return not os.path.lexists(candidate_path)
def on_text_change(
    story_name: str,
    story_context: str,
    categories_context: str
) -> gr.update:
    """Decide whether the submit button may be clicked.

    The button is interactive only when all three text values are non-empty
    and ``check_story_name`` accepts the story name.

    Args:
        story_name: Current content of the story-name text box.
        story_context: Current content of the story-context text box.
        categories_context: Current content of the categories-context box.

    Returns:
        ``gr.update(interactive=True)`` or ``gr.update(interactive=False)``.
    """
    # Any empty field disables the button without checking the name.
    if not (story_name and story_context and categories_context):
        return gr.update(interactive=False)
    if not check_story_name(story_name):
        return gr.update(interactive=False)
    return gr.update(interactive=True)
def save_story_information(
    story_name: str,
    story_context: str,
    categories_context: str,
    num_questions: int,
    num_options: int,
    num_categories: int
):
    """Create the story's directory and store its settings in ``story.json``.

    The directory is named after ``transform_story_name(story_name)`` and is
    created with ``os.mkdir`` inside ``get_session_dir()``. ``os.mkdir``
    raises ``FileExistsError`` if that directory already exists. The
    settings are then handed to ``save_information`` together with the path
    of ``story.json`` inside the new directory.

    Args:
        story_name: Name of the story as entered by the user.
        story_context: Free-text story context.
        categories_context: Free-text categories context.
        num_questions: Value of the questions slider.
        num_options: Value of the options slider.
        num_categories: Value of the categories slider.
    """
    settings = {
        'story_name': story_name,
        'story_context': story_context,
        'categories_context': categories_context,
        'num_questions': num_questions,
        'num_options': num_options,
        'num_categories': num_categories
    }

    target_dir = os.path.join(
        get_session_dir(),
        transform_story_name(story_name)
    )
    os.mkdir(target_dir)

    save_information(settings, os.path.join(target_dir, 'story.json'))
def on_submit(
    story_name: str,
    story_context: str,
    categories_context: str,
    num_questions: int,
    num_options: int,
    num_categories: int
) -> Dict[Union[gr.Text, gr.Slider, gr.Button, gr.Chatbot], gr.update]:
    """Handle a click on the submit button of the set-up form.

    Steps, in order:
      1. Save the entered settings with ``save_story_information``.
      2. Store the transformed story name on the coordinator module
         (``C.story_name``).
      3. Call ``init_graph`` with the story settings.
      4. Call ``on_chatbot_response([])`` to obtain the chatbot updates.
      5. Merge those with updates that hide every form widget and with the
         result of ``sidebar.view_screen()``.

    Returns:
        The merged update mapping. On a key clash the later source wins
        (chatbot, then form widgets, then sidebar).
    """
    # Everything except the name is passed on unchanged to both helpers.
    story_settings = (
        story_context,
        categories_context,
        num_questions,
        num_options,
        num_categories,
    )

    save_story_information(story_name, *story_settings)
    C.story_name = transform_story_name(story_name)
    init_graph(*story_settings)

    chatbot_updates = on_chatbot_response([])
    hidden_form_updates = get_widgets_updates(False)
    sidebar_updates = sidebar.view_screen()
    return chatbot_updates | hidden_form_updates | sidebar_updates
def render():
    """Render the set-up form and register its event handlers.

    Every widget from ``get_widgets()`` is rendered in order. Each text box
    also gets a ``change`` handler that passes the three text values to
    ``on_text_change`` and sends the result to the submit button. After the
    loop, the submit button's ``click`` is bound to ``on_submit``.
    """
    text_inputs = (
        story_name_textbox,
        story_context_textbox,
        categories_context_textbox,
    )

    for form_widget in get_widgets():
        form_widget.render()
        # Only the text boxes affect whether the form can be submitted.
        if isinstance(form_widget, gr.Text):
            form_widget.change(
                fn=on_text_change,
                inputs=list(text_inputs),
                outputs=[story_information_submit_button]
            )

    # on_submit receives the value of every form widget except the button.
    submit_inputs = [
        form_widget
        for form_widget in get_widgets()
        if not isinstance(form_widget, gr.Button)
    ]
    # Outputs: the form widgets, the chat widgets, then the sidebar widgets.
    submit_outputs = (
        get_widgets()
        + [chatbot, restart_button]
        + option_buttons
        + sidebar.get_widgets()
    )
    story_information_submit_button.click(
        fn=on_submit,
        inputs=submit_inputs,
        outputs=submit_outputs
    )