Spaces:
Sleeping
Sleeping
File size: 2,583 Bytes
b151e60 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import gradio as gr
from typing import Dict, List, Union
from src.frontend import (
continue_story_screen,
story_information_widgets
)
from src.utils.init_openai import init_client
# Components that make up the start screen: an API-key box plus a row with
# the "new story" / "continue story" buttons. Each one carries a stable
# `key` so Gradio can identify it across re-renders.
# NOTE(review): Gradio documents `preserved_by_key` as the constructor
# parameter(s) to keep across `gr.render` re-renders (default 'value');
# passing 'key' may not be what was intended — confirm.

# Horizontal container that will hold the two story buttons.
button_row = gr.Row(
    preserved_by_key='key',
    key='start_screen_button_row',
)

# Switches to the screen for starting a fresh story.
new_story_button = gr.Button(
    key='new_story_button',
    preserved_by_key='key',
    value='Begin a new adventure!',
    interactive=True,
    visible=True,
)

# Switches to the screen for resuming an existing story.
continue_story_button = gr.Button(
    key='continue_story_button',
    preserved_by_key='key',
    value='Continue an existing adventure!',
    interactive=True,
    visible=True,
)

# Masked input for the user's OpenAI API key; read by both button handlers.
api_key_textbox = gr.Text(
    key='api_key_textbox',
    preserved_by_key='key',
    label='OpenAI API Key',
    placeholder='Enter your OpenAI API key',
    type='password',
    interactive=True,
    visible=True,
)
def get_widgets() -> List[Union[gr.Text, gr.Button, gr.Row]]:
    """Return a new list of every component owned by the start screen.

    Returns:
        ``[button_row, new_story_button, continue_story_button,
        api_key_textbox]`` in that order; a fresh list on every call, so
        callers may concatenate or mutate it freely.
    """
    owned_components = (
        button_row,
        new_story_button,
        continue_story_button,
        api_key_textbox,
    )
    return list(owned_components)
def get_wigets_updates(
    is_visible: bool = False
) -> Dict[Union[gr.Text, gr.Button, gr.Row], Union[dict, gr.Row]]:
    """Build a visibility update for every start-screen component.

    Args:
        is_visible: ``True`` to show the components, ``False`` to hide them.

    Returns:
        A mapping from each component in ``get_widgets()`` to its update,
        suitable for returning from a Gradio event handler. ``gr.Row``
        components map to a new ``gr.Row(visible=...)`` instance; every other
        component maps to the dict produced by ``gr.update(visible=...)``.
    """
    return {
        widget: (
            gr.Row(visible=is_visible)
            if isinstance(widget, gr.Row)
            else gr.update(visible=is_visible)
        )
        for widget in get_widgets()
    }


# Correctly spelled alias that matches the ``get_widgets_updates`` name exposed
# by the other screen modules; the misspelled name is kept for existing callers.
get_widgets_updates = get_wigets_updates
def _on_submit(
    api_key: str,
    next_screen
) -> Dict[Union[gr.Text, gr.Slider, gr.Button, gr.Row], Union[dict, gr.Row]]:
    """Initialise the OpenAI client and switch from this screen to another.

    Args:
        api_key: The value of ``api_key_textbox``, passed to ``init_client``.
        next_screen: The screen to reveal on success; any object exposing
            ``get_widgets_updates(is_visible)`` (in this file, one of the
            ``src.frontend`` screen modules).

    Returns:
        Merged visibility updates. If ``init_client`` succeeds (truthy), the
        start screen is hidden and ``next_screen`` is shown; otherwise the
        start screen stays visible and ``next_screen`` stays hidden.
        NOTE(review): a failed key gives no feedback here unless
        ``init_client`` reports it itself — confirm.
    """
    # Coerce to a real bool: the result is used as a visibility flag, and
    # init_client is only known (from the original truthiness check) to
    # return something truthy/falsy.
    client_ready = bool(init_client(api_key))
    return (
        get_wigets_updates(not client_ready)
        | next_screen.get_widgets_updates(client_ready)
    )


def on_submit_new_story(
    api_key: str
) -> Dict[Union[gr.Text, gr.Slider, gr.Button, gr.Row], Union[dict, gr.Row]]:
    """Handle "Begin a new adventure!": on a valid key, show the story-setup widgets."""
    return _on_submit(api_key, story_information_widgets)


def on_submit_continue_story(
    api_key: str
) -> Dict[Union[gr.Text, gr.Slider, gr.Button, gr.Row], Union[dict, gr.Row]]:
    """Handle "Continue an existing adventure!": on a valid key, show the continue screen."""
    return _on_submit(api_key, continue_story_screen)
def render():
    """Lay out the start screen and attach its click handlers.

    Renders the API-key box, then the button row with both story buttons
    inside it. Each button's click handler reads the API key and returns
    updates for this screen's components plus those of the screen it leads to.
    """
    api_key_textbox.render()

    # Render the row, then open it as a context so the buttons land inside it.
    button_row.render()
    with button_row:
        for story_button in (new_story_button, continue_story_button):
            story_button.render()

    # (button, handler, screen it switches to) — listeners are registered in
    # this order.
    transitions = (
        (new_story_button, on_submit_new_story, story_information_widgets),
        (continue_story_button, on_submit_continue_story,
         continue_story_screen),
    )
    for story_button, handler, target_screen in transitions:
        story_button.click(
            fn=handler,
            inputs=[api_key_textbox],
            outputs=get_widgets() + target_screen.get_widgets(),
        )
|