Spaces:
Sleeping
Sleeping
File size: 5,919 Bytes
b151e60 5ac7eb0 b151e60 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
import gradio as gr
from langgraph.graph.graph import CompiledGraph
from typing import Dict, List, Union
import src.agents.coordinator as C
from src.dataclasses.process_state import ProcessState
# Compiled coordinator graph; stays None until init_graph() has been called.
graph: Union[CompiledGraph, None] = None
# Number of option buttons pre-built below; init_graph() may keep fewer of them.
MAX_OPTIONS = 16
# Live process state fed to graph.invoke() on each bot turn; on_restart_button_click()
# resets it to these same initial values.
state: ProcessState = {
    'question_num': 0,
    'state_name': 'start',
    'user_choice': '',
    'previous_state_names': [],
    'previous_questions': [],
    'previous_user_choices': [],
    'builder_instruction': '',
    # Used as the key into C.questions when presenting the next question.
    'foll_question_uuid': ''
}
# Shallow snapshots of `state`, appended before every bot turn.
state_history: List[ProcessState] = []
# Questions per storyline; stays None until init_graph() has been called.
num_questions_: Union[int, None] = None
# Chat transcript in Gradio "messages" format; hidden until the screen is shown.
chatbot = gr.Chatbot(
    key='chatbot',
    preserved_by_key='key',
    type='messages',
    visible=False,
)
# Pool of MAX_OPTIONS answer buttons. They are created unrendered so that render()
# can place them inside `button_row`.
option_buttons = [
    gr.Button(
        key=f'option_button_{i}',
        preserved_by_key='key',
        value='',
        render=False,
        visible=False,
    )
    for i in range(1, MAX_OPTIONS + 1)
]
# Revealed by on_chatbot_response() once a storyline has been evaluated.
restart_button = gr.Button(
    value='Explore other storylines!',
    key='restart_button',
    preserved_by_key='key',
    interactive=True,
    visible=False,
)
# Container row for the option buttons.
button_row = gr.Row(key='button_row', preserved_by_key='key')
def init_graph(
    story_context: str,
    categories_context: str,
    num_questions: int,
    num_options: int,
    num_categories: int
) -> None:
    """Build the coordinator graph and trim the option-button pool.

    Args:
        story_context: Forwarded to ``C.init_graph``.
        categories_context: Forwarded to ``C.init_graph``.
        num_questions: Questions per storyline. It is stored in ``num_questions_``,
            which on_chatbot_response() compares against ``state['question_num']``.
            It is also forwarded to ``C.init_graph``.
        num_options: Number of option buttons to keep (also forwarded). Buttons past
            this count are unrendered and dropped from ``option_buttons``.
        num_categories: Forwarded to ``C.init_graph``.
    """
    global graph, option_buttons, num_questions_
    num_questions_ = num_questions
    graph = C.init_graph(
        story_context,
        categories_context,
        num_questions,
        num_options,
        num_categories
    )
    # Walk the buttons that actually remain instead of indexing up to MAX_OPTIONS.
    # After an earlier call has truncated the list, indexing would raise IndexError;
    # slicing makes repeated calls safe.
    # NOTE: only MAX_OPTIONS buttons exist, so num_options > MAX_OPTIONS keeps them all.
    for surplus_button in option_buttons[num_options:]:
        surplus_button.unrender()
    option_buttons = option_buttons[:num_options]
def on_user_response(
    user_choice: str,
    history: List[Dict[str, str]]
) -> Dict[Union[gr.Button, gr.Chatbot], Union[List[Dict[str, str]], gr.update]]:
    """Record the clicked option and echo it into the chat.

    Args:
        user_choice: The label of the clicked option button.
        history: Chat history in Gradio ``messages`` format.

    Returns:
        An update dict: the chatbot gets the history extended with the user's message,
        and every option button is hidden.
    """
    # Undo the '_' -> ' ' conversion applied to option labels for display; presumably
    # this makes the value match the keys of `question.options` again.
    state['user_choice'] = user_choice.replace(' ', '_')
    updates = {chatbot: history + [{'role': 'user', 'content': user_choice}]}
    for hidden_button in option_buttons:
        updates[hidden_button] = gr.update(visible=False)
    return updates
def control_screen_widgets() -> List[Union[gr.Chatbot, gr.Row, gr.Button]]:
    """Return the components that control_screen() updates, in the same order."""
    widgets = [chatbot, button_row]
    widgets.append(restart_button)
    return widgets
def control_screen(
    is_visible: bool
) -> Dict[Union[gr.Chatbot, gr.Row, gr.Button], Union[gr.update, gr.Row]]:
    """Show or hide the chat screen.

    Args:
        is_visible: Visibility applied to the chatbot and the option-button row.

    Returns:
        An update dict for the chatbot, the button row, and the restart button. The
        restart button is always hidden, whatever the value of *is_visible*.
    """
    updates = {}
    updates[chatbot] = gr.update(visible=is_visible)
    updates[button_row] = gr.Row(visible=is_visible)
    updates[restart_button] = gr.update(visible=False)
    return updates
def _clean_text(text: str) -> str:
    """Turn underscores into spaces, then strip surrounding double and single quotes."""
    return text.replace('_', ' ').strip('\"').strip('\'')


def _question_turn(
    history: List[Dict[str, str]]
) -> tuple:
    """Present the question that the current `state` points at.

    Returns:
        A tuple of (extended chat history, button update dict). The update dict fills
        one option button per answer and hides the restart button.
    """
    scenario = C.states[state['question_num']][state['state_name']]
    question = C.questions[state['foll_question_uuid']]
    bot_message = {
        'role': 'assistant',
        'content': scenario + '\n' + _clean_text(question.question)
    }
    options = [_clean_text(option) for option in question.options.keys()]
    # zip() stops at the shorter side: extra options beyond the rendered buttons are
    # dropped, and buttons with no option keep their previous (hidden) state.
    button_updates = {
        button: gr.update(value=option, visible=True)
        for option, button in zip(options, option_buttons)
    }
    button_updates[restart_button] = gr.update(visible=False)
    return history + [bot_message], button_updates


def _result_turn(
    history: List[Dict[str, str]]
) -> tuple:
    """Evaluate the answers, describe the closest category, and offer a restart.

    Returns:
        A tuple of (extended chat history, button update dict). The update dict hides
        every option button and shows the restart button.
    """
    selected_category, reason = C.evaluation(state)
    category = next(
        (candidate for candidate in C.categories
         if candidate.name == selected_category),
        None
    )
    if category is not None:
        description, traits = category.description, category.traits
        C.categories_seen[selected_category] = True
    else:
        # The evaluator may return a name that is not in C.categories. Show empty text
        # rather than the literal 'None', and do not record the unknown name as seen.
        description, traits = '', ''
    bot_messages = [
        {'role': 'assistant', 'content': f'We think that you are closest to **{selected_category}**!'},
        {'role': 'assistant', 'content': f'## {selected_category}\n\n{description}\n\n{traits}'},
        {'role': 'assistant', 'content': reason}
    ]
    button_updates = {
        button: gr.update(visible=False)
        for button in option_buttons
    }
    button_updates[restart_button] = gr.update(visible=True)
    return history + bot_messages, button_updates


def on_chatbot_response(
    history: List[Dict[str, str]]
) -> Dict[Union[gr.Chatbot, gr.Button], Union[List[Dict[str, str]], gr.update]]:
    """Advance the storyline by one bot turn.

    The current ``state`` is snapshotted into ``state_history`` first. Then:

    * While fewer than ``num_questions_`` questions have been asked, the graph is
      invoked to advance ``state``, and the next scenario and question are shown.
    * Otherwise the answers are evaluated and the closest category is shown.

    ``C.save_coordinator()`` is then called.

    Args:
        history: Chat history in Gradio ``messages`` format.

    Returns:
        An update dict for the chatbot (made visible with the new history), the option
        buttons, and the restart button.
    """
    global state
    # Shallow snapshot: nested lists are still shared with the live state.
    state_history.append(dict(state))
    if state['question_num'] < num_questions_:
        state = graph.invoke(state)
        updated_history, button_updates = _question_turn(history)
    else:
        updated_history, button_updates = _result_turn(history)
    # NOTE(review): the original indentation was lost, so it is unclear whether this save
    # ran only after evaluation or on every turn. It is kept for every turn (a superset
    # of the saves) — confirm.
    C.save_coordinator()
    return {chatbot: gr.update(value=updated_history, visible=True)} | button_updates
def on_restart_button_click() -> Dict[
    Union[gr.Chatbot, gr.Button], Union[List[Dict[str, str]], gr.update]
]:
    """Start a new storyline from scratch.

    Resets the module-level ``state`` to its initial values. It then returns an update
    that hides the restart button, merged with the updates on_chatbot_response()
    produces for an empty chat history.
    """
    global state
    fresh_state = {
        'question_num': 0,
        'state_name': 'start',
        'user_choice': '',
        'previous_state_names': [],
        'previous_questions': [],
        'previous_user_choices': [],
        'builder_instruction': '',
        'foll_question_uuid': ''
    }
    state = fresh_state
    hide_restart = {restart_button: gr.update(visible=False)}
    first_turn = on_chatbot_response([])
    # On key collisions the right operand of `|` wins. on_chatbot_response() always
    # sets its own restart_button update, so that one takes precedence.
    return hide_restart | first_turn
def render():
    """Render the chat screen into the current Gradio context and wire its events.

    Each option button gets a click handler. The handler first runs on_user_response()
    and then on_chatbot_response(). The restart button triggers
    on_restart_button_click().
    """
    chatbot.render()
    button_row.render()
    with button_row:
        for option_button in option_buttons:
            # Detach the pre-built button, then re-render it so it lands inside the row.
            option_button.unrender()
            option_button.render()
            click_event = option_button.click(
                fn=on_user_response,
                inputs=[option_button, chatbot],
                outputs=[*option_buttons, chatbot],
            )
            click_event.then(
                fn=on_chatbot_response,
                inputs=[chatbot],
                outputs=[*option_buttons, chatbot, restart_button],
            )
    # NOTE(review): the original indentation was lost. The restart button is placed as a
    # sibling after the row, consistent with control_screen() toggling it separately
    # from button_row — confirm.
    restart_button.render()
    restart_button.click(
        fn=on_restart_button_click,
        inputs=None,
        outputs=[*option_buttons, chatbot, restart_button],
    )
|