from typing import Dict, List, Union

import gradio as gr

from src.frontend import (
    continue_story_screen,
    story_information_widgets
)
from src.utils.init_openai import init_client

button_row = gr.Row(
    key='start_screen_button_row',
    preserved_by_key='key'
)
new_story_button = gr.Button(
    value='Begin a new adventure!',
    visible=True,
    interactive=True,
    key='new_story_button',
    preserved_by_key='key'
)
continue_story_button = gr.Button(
    value='Continue an existing adventure!',
    visible=True,
    interactive=True,
    key='continue_story_button',
    preserved_by_key='key'
)
api_key_textbox = gr.Text(
    type='password',
    label='OpenAI API Key',
    placeholder='Enter your OpenAI API key',
    interactive=True,
    visible=True,
    key='api_key_textbox',
    preserved_by_key='key'
)


def get_widgets() -> List[Union[gr.Text, gr.Button, gr.Row]]:
    """Return every widget that belongs to the start screen."""
    return [
        button_row,
        new_story_button,
        continue_story_button,
        api_key_textbox
    ]


def get_widgets_updates(
    is_visible: bool = False
) -> Dict[Union[gr.Text, gr.Button, gr.Row], Union[dict, gr.Row]]:
    """Map every start-screen widget to an update that sets its visibility."""
    return {
        widget: (
            gr.Row(visible=is_visible) if isinstance(widget, gr.Row)
            else gr.update(visible=is_visible)
        )
        for widget in get_widgets()
    }


# Backward-compatible alias for the original, misspelled name.
get_wigets_updates = get_widgets_updates


def on_submit_new_story(
    api_key: str
) -> Dict[Union[gr.Text, gr.Slider, gr.Button, gr.Row], Union[dict, gr.Row]]:
    """Switch to the story-information widgets if init_client succeeds."""
    # Hide this screen only when the client was initialised; otherwise keep it.
    is_ready = bool(init_client(api_key))
    return (
        get_widgets_updates(not is_ready)
        | story_information_widgets.get_widgets_updates(is_ready)
    )


def on_submit_continue_story(
    api_key: str
) -> Dict[Union[gr.Text, gr.Slider, gr.Button, gr.Row], Union[dict, gr.Row]]:
    """Switch to the continue-story screen if init_client succeeds."""
    # Hide this screen only when the client was initialised; otherwise keep it.
    is_ready = bool(init_client(api_key))
    return (
        get_widgets_updates(not is_ready)
        | continue_story_screen.get_widgets_updates(is_ready)
    )


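def render():
    """Render the start-screen widgets and wire up the button handlers."""
    api_key_textbox.render()
    button_row.render()
    with button_row:
        new_story_button.render()
        continue_story_button.render()

    # Each handler returns a dict keyed by component, so the outputs must list
    # both this screen's widgets and those of the screen being switched to.
    new_story_button.click(
        fn=on_submit_new_story,
        inputs=[api_key_textbox],
        outputs=get_widgets() + story_information_widgets.get_widgets()
    )
    continue_story_button.click(
        fn=on_submit_continue_story,
        inputs=[api_key_textbox],
        outputs=get_widgets() + continue_story_screen.get_widgets()
    )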