Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| from igraph import Graph, EdgeSeq | |
| from typing import Dict, List, Union | |
| import src.agents.coordinator as C | |
# Module-level widgets, all created hidden. render() places them in the
# layout and control_screen() toggles their visibility.
# NOTE(review): Gradio's `preserved_by_key` lists constructor parameters to
# keep across re-renders (default "value"); passing 'key' looks unintended —
# confirm.
categories_row = gr.Row(visible=False, key='categories_row', preserved_by_key='key')
graph_row = gr.Row(visible=False, key='graph_row', preserved_by_key='key')

# Sixteen category slots keyed category_button_1 .. category_button_16;
# update_categories() fills them positionally from C.categories.
category_buttons = [
    gr.Button(
        value='',
        visible=False,
        key=f'category_button_{slot}',
        preserved_by_key='key',
    )
    for slot in range(1, 17)
]

# Read-only text area that on_category_click() fills with category details.
category_text = gr.Text(
    value='Select a category!',
    label='',
    visible=False,
    interactive=False,
    key='category_text',
    preserved_by_key='key',
)

graph_plot = gr.Plot(visible=False, key='graph_plot', preserved_by_key='key')
def control_screen(
    is_visible: bool
) -> Dict[Union[gr.Row, gr.Text], Union[gr.Row, gr.update]]:
    """Show or hide the two screen rows and the category text.

    NOTE(review): a second ``control_screen`` is defined later in this module
    (it also refreshes the graph and the category buttons) and replaces this
    one at import time, so this definition is never what callers get. It is
    a candidate for deletion.

    Args:
        is_visible: Target visibility for both rows and the category text.

    Returns:
        Gradio update mapping from each component to its visibility update.
    """
    return {
        categories_row: gr.Row(visible=is_visible),
        graph_row: gr.Row(visible=is_visible),
        category_text: gr.update(visible=is_visible),
    }
def update_categories() -> Dict[Union[gr.Button, gr.Text], gr.update]:
    """Refresh the category buttons from the coordinator's "seen" flags.

    Buttons are paired with ``C.categories`` positionally; surplus buttons or
    surplus categories are skipped (``zip`` semantics). A seen category shows
    its name, is clickable and uses the primary variant. An unseen one shows
    ``'???'``, is not interactive and uses the secondary variant. The category
    text is made visible and reset to its prompt.

    Returns:
        Gradio update mapping for ``category_text`` and each paired button.
    """
    updates = {
        category_text: gr.update(visible=True, value='Select a category!')
    }
    for button, category in zip(category_buttons, C.categories):
        seen = C.categories_seen[category.name]
        updates[button] = gr.update(
            value=category.name if seen else '???',
            visible=True,
            interactive=seen,
            variant='primary' if seen else 'secondary',
        )
    return updates
def on_category_click(category_name: str) -> str:
    """Return the display text for the category named *category_name*.

    The text is the first matching category's name, description and traits,
    separated by blank lines.

    Args:
        category_name: The clicked button's value, which update_categories()
            sets to the category name for seen categories.

    Returns:
        The formatted category text, or ``''`` when no category matches.
    """
    for category in C.categories:
        if category.name == category_name:
            return (
                category.name + '\n\n'
                + category.description + '\n\n'
                + category.traits
            )
    # No match: honour the ``str`` return type instead of falling through to None.
    return ''
def _wrap_hover_text(text, width=30):
    """Wrap *text* as Plotly hover markup, keeping blank-line paragraph breaks.

    Each paragraph (text separated by a blank line) is wrapped independently
    to at most *width* characters per line. Lines are joined with ``<br>`` and
    paragraphs with ``<br><br>``. Whitespace-only paragraphs are dropped.
    """
    import textwrap

    wrapped_paragraphs = [
        '<br>'.join(textwrap.wrap(paragraph, width))
        for paragraph in text.split('\n\n')
        if paragraph.strip()
    ]
    return '<br><br>'.join(wrapped_paragraphs)


def _collect_nodes():
    """Return ``(names, descriptions)`` for the states of stages 1..C.num_questions_.

    Names are ``'<state> (Stage <level>)'``; the stage suffix keeps vertex
    names unique across stages.
    """
    names, descriptions = [], []
    for level in range(1, C.num_questions_ + 1):
        for state, description in C.states[level].items():
            names.append(f'{state} (Stage {level})')
            descriptions.append(description)
    return names, descriptions


def _collect_edges():
    """Return ``(edges, questions, options)`` linking each stage to the next.

    For every state at levels 1..C.num_questions_ - 1, the options of its
    mapped question are grouped by follow-up state. Each group becomes one
    edge from ``'<state> (Stage <level>)'`` to
    ``'<follow-up> (Stage <level + 1>)'``, labelled with the question text and
    the group's options joined by ``', '``.
    """
    edges, questions, options = [], [], []
    for level in range(1, C.num_questions_):
        for state, question_uuid in C.state_question_map[level].items():
            question = C.questions[question_uuid]
            source = f'{state} (Stage {level})'
            previous_target = None
            # Sorting by follow-up state makes options that lead to the same
            # state adjacent, so they can be merged into a single edge.
            for option, target in sorted(question.options.items(), key=lambda item: item[1]):
                if previous_target is not None and previous_target == target:
                    options[-1] += f', {option}'
                else:
                    edges.append((source, f'{target} (Stage {level + 1})'))
                    questions.append(question.question)
                    options.append(option)
                previous_target = target
    return edges, questions, options


def update_graph():
    """Draw the stage/question graph from ``C`` as a visible Plotly plot.

    Vertices are the states of stages 1..C.num_questions_. Edges are the
    grouped answer options leading from one stage to the next. The graph is
    laid out with igraph's ``'rt'`` (Reingold-Tilford) layout and drawn as
    three traces: edge lines, hoverable state markers, and invisible markers
    at edge midpoints that carry the question/answer hover text.
    Adapted from https://plotly.com/python/tree-plots/.

    Returns:
        ``gr.Plot`` wrapping the figure, with ``visible=True``.
    """
    names, descriptions = _collect_nodes()
    edges, questions, options = _collect_edges()

    graph = Graph(directed=True)
    graph.add_vertices(names, attributes={'description': descriptions})
    graph.add_edges(edges, attributes={'question': questions, 'option': options})

    # Take the count from the graph itself rather than re-deriving it from
    # C.states, so layout indexing can never run past the real vertices.
    n_vertices = graph.vcount()
    layout = graph.layout('rt') if n_vertices else []
    xs = [layout[k][0] for k in range(n_vertices)]
    raw_ys = [layout[k][1] for k in range(n_vertices)]
    # Mirror vertically (as in the Plotly example) so the layout's
    # smallest-y level ends up at the top of the figure.
    top = max(raw_ys, default=0)
    ys = [2 * top - y for y in raw_ys]

    edge_xs, edge_ys = [], []
    mid_xs, mid_ys = [], []
    for edge in graph.es:
        source, target = edge.tuple
        # None separates consecutive edges so each is drawn as its own segment.
        edge_xs += [xs[source], xs[target], None]
        edge_ys += [ys[source], ys[target], None]
        mid_xs.append((xs[source] + xs[target]) / 2)
        mid_ys.append((ys[source] + ys[target]) / 2)

    node_labels = [
        _wrap_hover_text(name.replace('_', ' ') + '\n\n' + description)
        for name, description in zip(names, descriptions)
    ]
    # graph.es iterates in insertion order, so these line up with mid_xs/mid_ys.
    edge_labels = [
        _wrap_hover_text(
            question.replace('_', ' ') + '\n\n[ ' + option.replace('_', ' ') + ' ]'
        )
        for question, option in zip(questions, options)
    ]

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=edge_xs, y=edge_ys,
        mode='lines',
        line=dict(color='rgb(210,210,210)', width=1),
    ))
    fig.add_trace(go.Scatter(
        x=xs, y=ys,
        mode='markers',
        marker=dict(
            symbol='circle-dot', size=18, color='#6175c1',
            line=dict(color='rgb(50,50,50)', width=1),
        ),
        text=node_labels,
        hoverinfo='text',
        opacity=0.8,
    ))
    # Invisible midpoint markers: they only carry each edge's hover text.
    fig.add_trace(go.Scatter(
        x=mid_xs, y=mid_ys,
        mode='markers',
        marker=dict(
            symbol='circle-dot', size=0, color='#42c744',
            line=dict(color='rgb(50,50,50)', width=0),
        ),
        text=edge_labels,
        hoverinfo='text',
        opacity=0,
    ))
    hidden_axis = dict(
        showline=False,
        zeroline=False,
        showgrid=False,
        showticklabels=False,
    )
    fig.update_layout(showlegend=False, xaxis=hidden_axis, yaxis=hidden_axis)
    return gr.Plot(fig, visible=True)
def control_screen_widgets() -> List[Union[gr.Row, gr.Text, gr.Plot, gr.Button]]:
    """Return every component that control_screen() may update.

    The list holds both rows, the category text, the graph plot and all
    category buttons. A fresh list is built on each call.
    NOTE(review): presumably used as the ``outputs`` list when wiring
    control_screen into an event — confirm with the caller.
    """
    return [categories_row, category_text, graph_row, graph_plot, *category_buttons]
def control_screen(
    is_visible: bool
) -> Dict[Union[gr.Plot, gr.Row, gr.Button, gr.Text], Union[gr.update, gr.Row, gr.Plot]]:
    """Show or hide the category/graph screen and refresh its contents.

    Always refreshes the category buttons via update_categories().

    When showing, the graph is rebuilt with update_graph(), and the category
    text is made visible and reset to its prompt. When hiding, the rows, the
    plot and the category text are hidden.

    Args:
        is_visible: Whether the screen should be shown.

    Returns:
        Gradio update mapping for the rows, the plot, the category text and
        the category buttons.
    """
    # update_categories() also returns a category_text entry with
    # visible=True. Start from it so the explicit updates below take
    # precedence; merging it last would keep the text visible when hiding.
    updates = update_categories()
    updates[categories_row] = gr.Row(visible=is_visible)
    updates[graph_row] = gr.Row(visible=is_visible)
    updates[graph_plot] = update_graph() if is_visible else gr.update(visible=False)
    if not is_visible:
        updates[category_text] = gr.update(visible=False)
    return updates
def render():
    """Place the category and graph widgets into the current Gradio layout.

    Renders categories_row containing every category button, then
    category_text, then graph_row containing graph_plot. Each button's click
    is wired to on_category_click, which writes into category_text.
    Presumably called inside an enclosing ``gr.Blocks`` context — confirm
    with the caller.
    """
    categories_row.render()
    with categories_row:
        for button in category_buttons:
            button.render()
            # The button's own value (set by update_categories) is the input.
            button.click(
                fn=on_category_click,
                inputs=[button],
                outputs=[category_text]
            )
    # NOTE(review): the source's indentation was lost; this assumes
    # category_text sits outside categories_row, since control_screen toggles
    # its visibility separately from the rows — confirm against the original.
    category_text.render()
    graph_row.render()
    with graph_row:
        graph_plot.render()