Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from langgraph.graph.graph import CompiledGraph | |
| from typing import Dict, List, Union | |
| import src.agents.coordinator as C | |
| from src.dataclasses.process_state import ProcessState | |
# Compiled coordinator graph; None until init_graph() has been called.
graph: Union[CompiledGraph, None] = None
# Number of option buttons created up front; init_graph() may trim the list.
MAX_OPTIONS = 16
# Quiz state for the current run. It lives at module level, so it is shared by
# every session served by this process -- NOTE(review): confirm that a single
# concurrent user is the intended deployment.
state: ProcessState = {
    'question_num': 0,
    'state_name': 'start',
    'user_choice': '',
    'previous_state_names': [],
    'previous_questions': [],
    'previous_user_choices': [],
    'builder_instruction': '',
    'foll_question_uuid': ''
}
# Shallow snapshots of `state`, appended by on_chatbot_response() before each step.
state_history: List[ProcessState] = []
# Questions per run; None until init_graph() has been called.
num_questions_: Union[int, None] = None
# Chat transcript in "messages" format; created hidden and revealed later by
# the screen/response handlers.
chatbot = gr.Chatbot(
    type='messages',
    visible=False,
    key='chatbot',
    preserved_by_key='key',
)


def _hidden_option_button(number: int) -> gr.Button:
    """Create a hidden, not-yet-rendered option button keyed ``option_button_<number>``."""
    return gr.Button(
        value='',
        visible=False,
        render=False,
        key=f'option_button_{number}',
        preserved_by_key='key',
    )


# One button per possible answer option, numbered from 1 to MAX_OPTIONS.
option_buttons = list(map(_hidden_option_button, range(1, MAX_OPTIONS + 1)))

# Shown once a run has been evaluated, to start a new storyline.
restart_button = gr.Button(
    'Explore other storylines!',
    interactive=True,
    visible=False,
    key='restart_button',
    preserved_by_key='key',
)

# Row that hosts the option buttons when render() is called.
button_row = gr.Row(key='button_row', preserved_by_key='key')
def init_graph(
    story_context: str,
    categories_context: str,
    num_questions: int,
    num_options: int,
    num_categories: int
) -> None:
    """Build the coordinator graph and size the option-button list.

    Args:
        story_context: Passed through to ``C.init_graph``.
        categories_context: Passed through to ``C.init_graph``.
        num_questions: Questions per run; stored in ``num_questions_`` so that
            ``on_chatbot_response`` knows when to switch to evaluation.
        num_options: Passed to ``C.init_graph``; option buttons beyond this
            count are unrendered and dropped from ``option_buttons``.
        num_categories: Passed through to ``C.init_graph``.
    """
    global graph, option_buttons, num_questions_
    num_questions_ = num_questions
    graph = C.init_graph(
        story_context,
        categories_context,
        num_questions,
        num_options,
        num_categories
    )
    # Walk the surplus buttons that are actually in the list. Indexing with
    # range(num_options, MAX_OPTIONS) raised IndexError once a previous call
    # had already truncated option_buttons.
    # NOTE(review): dropped buttons are not recoverable, so a later call with
    # a larger num_options still shows at most the previous count.
    for surplus_button in option_buttons[num_options:]:
        surplus_button.unrender()
    option_buttons = option_buttons[:num_options]
def on_user_response(
    user_choice: str,
    history: List[Dict[str, str]]
) -> Dict[Union[gr.Button, gr.Chatbot], Union[List[Dict[str, str]], gr.update]]:
    """Handle a click on one of the option buttons.

    Stores the clicked label in the global quiz state (spaces replaced by
    underscores), echoes it into the chat as a user message, and hides every
    option button until the next question arrives.

    Args:
        user_choice: Label of the clicked button.
        history: Current chat history in Gradio "messages" format.

    Returns:
        Update mapping for the chatbot and all option buttons.
    """
    state['user_choice'] = user_choice.replace(' ', '_')

    echoed_history = list(history)
    echoed_history.append({'role': 'user', 'content': user_choice})

    updates = {chatbot: echoed_history}
    for option_button in option_buttons:
        updates[option_button] = gr.update(visible=False)
    return updates
def control_screen_widgets() -> List[Union[gr.Chatbot, gr.Row, gr.Button]]:
    """Return the components that ``control_screen`` updates."""
    screen_components = [chatbot, button_row, restart_button]
    return screen_components
def control_screen(
    is_visible: bool
) -> Dict[Union[gr.Chatbot, gr.Row, gr.Button], Union[gr.update, gr.Row]]:
    """Show or hide the chat screen.

    Args:
        is_visible: Visibility applied to the chatbot and the button row.

    Returns:
        Update mapping for the chatbot, the button row and the restart button;
        the restart button is hidden regardless of ``is_visible``.
    """
    screen_updates = {}
    screen_updates[chatbot] = gr.update(visible=is_visible)
    screen_updates[button_row] = gr.Row(visible=is_visible)
    screen_updates[restart_button] = gr.update(visible=False)
    return screen_updates
def _ask_next_question(
    history: List[Dict[str, str]]
) -> tuple[List[Dict[str, str]], Dict[gr.Button, dict]]:
    """Advance the graph one step and present the next scenario and question.

    Returns the history with the new assistant message appended, plus updates
    that fill the option buttons with the question's options.
    """
    global state
    state = graph.invoke(state)
    scenario = C.states[state['question_num']][state['state_name']]
    question = C.questions[state['foll_question_uuid']]
    question_str = question.question.replace('_', ' ').strip('\"').strip('\'')
    bot_message = [{'role': 'assistant', 'content': scenario + '\n' + question_str}]
    options = [
        option.replace('_', ' ').strip('\"').strip('\'')
        for option in question.options.keys()
    ]
    # zip() stops at the shorter sequence: surplus options are dropped and
    # surplus buttons get no update (they keep their current state).
    button_updates = {
        button: gr.update(value=option, visible=True)
        for option, button in zip(options, option_buttons)
    }
    button_updates[restart_button] = gr.update(visible=False)
    return history + bot_message, button_updates


def _present_evaluation(
    history: List[Dict[str, str]]
) -> tuple[List[Dict[str, str]], Dict[gr.Button, dict]]:
    """Evaluate the finished run and present the closest category.

    Marks the matched category in ``C.categories_seen``, appends the result
    messages to the history, hides the option buttons and shows the restart
    button.
    """
    selected_category, reason = C.evaluation(state)
    description, traits = None, None
    for category in C.categories:
        if category.name != selected_category:
            continue
        description = category.description
        traits = category.traits
        C.categories_seen[selected_category] = True
    # If no entry in C.categories matches selected_category, omit the missing
    # parts instead of rendering the literal text "None" to the user.
    details = [f'## {selected_category}'] + [
        f'{part}' for part in (description, traits) if part is not None
    ]
    bot_messages = [
        {'role': 'assistant', 'content': f'We think that you are closest to **{selected_category}**!'},
        {'role': 'assistant', 'content': '\n\n'.join(details)},
        {'role': 'assistant', 'content': reason}
    ]
    button_updates = {button: gr.update(visible=False) for button in option_buttons}
    button_updates[restart_button] = gr.update(visible=True)
    return history + bot_messages, button_updates


def on_chatbot_response(
    history: List[Dict[str, str]]
) -> Dict[Union[gr.Chatbot, gr.Button], Union[List[Dict[str, str]], gr.update]]:
    """Produce the assistant's next turn.

    While fewer than ``num_questions_`` questions have been asked, this advances
    the graph and shows the next question with its options. After that, it
    evaluates the run and shows the resulting category with a restart button.

    Args:
        history: Current chat history in Gradio "messages" format.

    Returns:
        Update mapping for the chatbot, the option buttons and the restart button.
    """
    # Shallow snapshot of the state before it is advanced.
    state_history.append({key: val for key, val in state.items()})
    if state['question_num'] < num_questions_:
        updated_history, button_updates = _ask_next_question(history)
    else:
        updated_history, button_updates = _present_evaluation(history)
    # Persist coordinator data after every turn.
    # NOTE(review): confirm this should not run only after evaluation.
    C.save_coordinator()
    return {chatbot: gr.update(value=updated_history, visible=True)} | button_updates
def on_restart_button_click() -> Dict[
    Union[gr.Chatbot, gr.Button], Union[List[Dict[str, str]], gr.update]
]:
    """Start a new run from scratch.

    Replaces the global quiz state with a fresh initial state, hides the
    restart button, and lets ``on_chatbot_response`` produce the next turn
    starting from an empty chat history.

    Returns:
        Update mapping combining the hidden restart button with the updates
        returned by ``on_chatbot_response([])``.
    """
    global state
    fresh_state = {
        'question_num': 0,
        'state_name': 'start',
        'user_choice': '',
        'previous_state_names': [],
        'previous_questions': [],
        'previous_user_choices': [],
        'builder_instruction': '',
        'foll_question_uuid': ''
    }
    state = fresh_state
    hide_restart = {restart_button: gr.update(visible=False)}
    first_turn = on_chatbot_response([])
    return hide_restart | first_turn
def render():
    """Render the chat widgets and register their click handlers.

    Clicking an option button runs ``on_user_response`` and then
    ``on_chatbot_response``. The restart button runs ``on_restart_button_click``.
    """
    chatbot.render()
    button_row.render()
    with button_row:
        for option_button in option_buttons:
            # Re-render inside the row context so the button is placed in it.
            option_button.unrender()
            option_button.render()
            user_turn = option_button.click(
                fn=on_user_response,
                inputs=[option_button, chatbot],
                outputs=[*option_buttons, chatbot],
            )
            user_turn.then(
                fn=on_chatbot_response,
                inputs=[chatbot],
                outputs=[*option_buttons, chatbot, restart_button],
            )
    # NOTE(review): rendered after (outside) button_row; confirm this matches
    # the intended layout.
    restart_button.render()
    restart_button.click(
        fn=on_restart_button_click,
        inputs=None,
        outputs=[*option_buttons, chatbot, restart_button],
    )