"""Gradio widgets for browsing discovered categories and the state/question graph."""

import textwrap
from typing import Dict, List, Optional, Tuple, Union

import gradio as gr
import pandas as pd
import plotly.graph_objects as go
from igraph import Graph, EdgeSeq

import src.agents.coordinator as C

# Number of category button slots created; categories past this count are never shown
# because update_categories() zips buttons with C.categories.
_NUM_CATEGORY_BUTTONS = 16
# Column width used when word-wrapping hover labels in the graph plot.
_HOVER_WRAP_WIDTH = 30

categories_row = gr.Row(
    visible=False,
    key='categories_row',
    preserved_by_key='key',
)
graph_row = gr.Row(
    visible=False,
    key='graph_row',
    preserved_by_key='key',
)
category_buttons = [
    gr.Button(
        value='',
        visible=False,
        key=f'category_button_{i + 1}',
        preserved_by_key='key',
    )
    for i in range(_NUM_CATEGORY_BUTTONS)
]
category_text = gr.Text(
    value='Select a category!',
    label='',
    visible=False,
    interactive=False,
    key='category_text',
    preserved_by_key='key',
)
graph_plot = gr.Plot(
    visible=False,
    key='graph_plot',
    preserved_by_key='key',
)


def update_categories() -> Dict[Union[gr.Button, gr.Text], gr.update]:
    """Build updates that reveal every category button and reset the detail text.

    A category whose entry in ``C.categories_seen`` is truthy shows its name as a
    clickable primary button; otherwise the button shows '???' and is disabled.

    Returns:
        Mapping from component to its ``gr.update``: ``category_text`` first,
        then one entry per button paired with a category.
    """
    updates = {
        category_text: gr.update(visible=True, value='Select a category!')
    }
    for button, category in zip(category_buttons, C.categories):
        seen = C.categories_seen[category.name]
        updates[button] = gr.update(
            value=category.name if seen else '???',
            visible=True,
            interactive=seen,
            variant='primary' if seen else 'secondary',
        )
    return updates


def on_category_click(category_name: str) -> Optional[str]:
    """Return the detail text for the category named ``category_name``.

    Used as a button click handler whose input is the button itself, so
    ``category_name`` is the clicked button's label.

    Returns:
        Name, description and traits separated by blank lines, or ``None`` when
        no category in ``C.categories`` has that name.
    """
    for category in C.categories:
        if category.name == category_name:
            return '\n\n'.join(
                (category.name, category.description, category.traits)
            )
    return None


def _wrap_hover_text(*paragraphs: str, width: int = _HOVER_WRAP_WIDTH) -> str:
    """Word-wrap each paragraph to ``width`` columns and join them as Plotly hover HTML.

    Paragraphs are wrapped independently and separated by a blank line
    (``<br><br>``). Wrapping the pre-joined text would collapse the ``\\n\\n``
    separator, because TextWrapper replaces every whitespace character
    (including newlines) with a space by default.
    """
    return '<br><br>'.join(
        '<br>'.join(textwrap.wrap(paragraph, width=width))
        for paragraph in paragraphs
    )


def _graph_nodes() -> Tuple[List[str], List[str]]:
    """Collect vertex names and descriptions for stages 1..``C.num_questions_``.

    Each vertex is named ``'<state> (Stage <level>)'`` so that equally named
    states on different stages stay distinct.
    """
    names: List[str] = []
    descriptions: List[str] = []
    for level in range(1, C.num_questions_ + 1):
        for name, description in C.states[level].items():
            names.append(f'{name} (Stage {level})')
            descriptions.append(description)
    return names, descriptions


def _graph_edges() -> Tuple[List[Tuple[str, str]], List[str], List[str]]:
    """Collect edges from each stage's states to the states their answers lead to.

    Options of one question that lead to the same follow-up state are merged
    into a single edge. That edge's option label is the comma-separated list
    of those options.

    Returns:
        ``(edges, questions, options)`` as parallel lists. ``edges`` holds
        ``(source_name, target_name)`` pairs in the vertex naming scheme of
        ``_graph_nodes``.
    """
    edges: List[Tuple[str, str]] = []
    questions: List[str] = []
    options: List[str] = []
    for level in range(1, C.num_questions_):
        for state, question_uuid in C.state_question_map[level].items():
            question = C.questions[question_uuid]
            prev_foll_state = None
            # Sorting by follow-up state makes options sharing a target adjacent,
            # so a single look-back is enough to merge them.
            for option, foll_state in sorted(
                question.options.items(), key=lambda item: item[1]
            ):
                if prev_foll_state is not None and prev_foll_state == foll_state:
                    options[-1] += f', {option}'
                else:
                    edges.append((
                        f'{state} (Stage {level})',
                        f'{foll_state} (Stage {level + 1})',
                    ))
                    questions.append(question.question)
                    options.append(option)
                prev_foll_state = foll_state
    return edges, questions, options


def update_graph() -> gr.Plot:
    """Render the state/question graph as a Plotly tree inside a visible ``gr.Plot``.

    Vertices are states (hover shows name and description). Edges are the
    question answers leading from one stage's state to the next; hovering an
    edge's invisible midpoint marker shows the question and the options.
    Adapted from https://plotly.com/python/tree-plots/.
    """
    nodes, descriptions = _graph_nodes()
    edges, questions, options = _graph_edges()

    fig = go.Figure()
    hidden_axis = dict(
        showline=False,
        zeroline=False,
        showgrid=False,
        showticklabels=False,
    )
    fig.update_layout(showlegend=False, xaxis=hidden_axis, yaxis=hidden_axis)
    if not nodes:
        # Nothing to lay out: max() over the coordinates below would fail.
        return gr.Plot(fig, visible=True)

    graph = Graph(directed=True)
    graph.add_vertices(nodes, attributes={'description': descriptions})
    graph.add_edges(edges, attributes={'question': questions, 'option': options})

    layout = graph.layout('rt')
    coords = [layout[k] for k in range(graph.vcount())]
    max_y = max(y for _, y in coords)
    # Mirror y as in the Plotly tree-plot recipe so the tree is drawn top-down.
    node_x = [x for x, _ in coords]
    node_y = [2 * max_y - y for _, y in coords]

    edge_x: list = []
    edge_y: list = []
    # Edge midpoints carry the hover labels for the edges.
    mid_x: List[float] = []
    mid_y: List[float] = []
    for edge in graph.es:
        source, target = edge.tuple
        # The trailing None breaks the line so each edge is a separate segment.
        edge_x += [node_x[source], node_x[target], None]
        edge_y += [node_y[source], node_y[target], None]
        mid_x.append((node_x[source] + node_x[target]) / 2)
        mid_y.append((node_y[source] + node_y[target]) / 2)

    node_labels = [
        _wrap_hover_text(node.replace('_', ' '), description)
        for node, description in zip(nodes, descriptions)
    ]
    # graph.es preserves insertion order, so labels line up with midpoints.
    edge_labels = [
        _wrap_hover_text(
            question.replace('_', ' '),
            '[ ' + option.replace('_', ' ') + ' ]',
        )
        for question, option in zip(questions, options)
    ]

    fig.add_trace(go.Scatter(
        x=edge_x,
        y=edge_y,
        mode='lines',
        line=dict(color='rgb(210,210,210)', width=1),
    ))
    fig.add_trace(go.Scatter(
        x=node_x,
        y=node_y,
        mode='markers',
        marker=dict(
            symbol='circle-dot',
            size=18,
            color='#6175c1',
            line=dict(color='rgb(50,50,50)', width=1),
        ),
        text=node_labels,
        hoverinfo='text',
        opacity=0.8,
    ))
    fig.add_trace(go.Scatter(
        x=mid_x,
        y=mid_y,
        mode='markers',
        marker=dict(
            symbol='circle-dot',
            size=0,
            color="#42c744",
            line=dict(color='rgb(50,50,50)', width=0),
        ),
        text=edge_labels,
        hoverinfo='text',
        opacity=0,
    ))
    return gr.Plot(fig, visible=True)


def control_screen_widgets() -> List[Union[gr.Row, gr.Text]]:
    """Return every component that ``control_screen`` may update (its outputs list)."""
    return [categories_row, category_text, graph_row, graph_plot, *category_buttons]


def control_screen(
    is_visible: bool
) -> Dict[Union[gr.Plot, gr.Row, gr.Button], Union[gr.update, gr.Row]]:
    """Show or hide the categories/graph screen.

    When showing, the graph is rebuilt and the category buttons and detail
    text are refreshed. When hiding, both rows, the plot and the detail text
    are hidden.

    Args:
        is_visible: Whether the screen should be shown.

    Returns:
        Mapping from component to its update, keyed by members of
        ``control_screen_widgets()``.
    """
    updates = {
        categories_row: gr.Row(visible=is_visible),
        graph_row: gr.Row(visible=is_visible),
        graph_plot: update_graph() if is_visible else gr.update(visible=False),
    }
    updates |= update_categories()
    if not is_visible:
        # update_categories() always makes category_text visible; honour the request.
        updates[category_text] = gr.update(visible=False)
    return updates


def render():
    """Place the pre-built components into the current Blocks layout and wire clicks."""
    categories_row.render()
    with categories_row:
        for button in category_buttons:
            button.render()
            button.click(
                fn=on_category_click,
                inputs=[button],
                outputs=[category_text],
            )
    # NOTE(review): original indentation was lost; category_text is assumed to sit
    # below the button row rather than inside it — confirm against the intended layout.
    category_text.render()
    graph_row.render()
    with graph_row:
        graph_plot.render()