sto-rai / src /frontend /progress_screen.py
yiiilonggg's picture
Fix loading json bugs
5ac7eb0
import itertools
import textwrap
from typing import Dict, List, Union

import gradio as gr
import pandas as pd
import plotly.graph_objects as go
from igraph import Graph, EdgeSeq

import src.agents.coordinator as C
# Containers for the two halves of the progress screen. Both start hidden;
# control_screen() toggles their visibility.
categories_row = gr.Row(visible=False, key='categories_row', preserved_by_key='key')
graph_row = gr.Row(visible=False, key='graph_row', preserved_by_key='key')

# Fixed pool of category buttons keyed category_button_1 .. category_button_16.
# update_categories() assigns their labels and reveals them.
_NUM_CATEGORY_BUTTONS = 16
category_buttons = [
    gr.Button(
        value='',
        visible=False,
        key=f'category_button_{number}',
        preserved_by_key='key',
    )
    for number in range(1, _NUM_CATEGORY_BUTTONS + 1)
]

# Read-only box that shows the details of the clicked category.
category_text = gr.Text(
    value='Select a category!',
    label='',
    visible=False,
    interactive=False,
    key='category_text',
    preserved_by_key='key',
)

# Placeholder for the state/question tree figure produced by update_graph().
graph_plot = gr.Plot(visible=False, key='graph_plot', preserved_by_key='key')
# The screen-visibility handler `control_screen` is defined once, after
# `control_screen_widgets`. An earlier duplicate that lived here was removed:
# the later definition rebinds the name at import time, so it was unreachable.
def update_categories() -> Dict[Union[gr.Button, gr.Text], gr.update]:
    """Build the updates that reveal the category buttons and reset the text.

    Each button is paired (in order) with an entry of ``C.categories``.
    Categories marked in ``C.categories_seen`` show their name, are clickable
    and use the primary variant. Unseen ones show ``'???'``, are disabled and
    use the secondary variant. The detail text is made visible and reset to
    its prompt.

    Returns:
        Mapping of component to ``gr.update``; the text entry comes first.
    """
    updates = {
        category_text: gr.update(visible=True, value='Select a category!')
    }
    for button, category in zip(category_buttons, C.categories):
        seen = C.categories_seen[category.name]
        if seen:
            label, variant = category.name, 'primary'
        else:
            label, variant = '???', 'secondary'
        updates[button] = gr.update(
            value=label,
            visible=True,
            interactive=seen,
            variant=variant,
        )
    return updates
def on_category_click(
    category_name: str,
    *,
    categories: Union[List, None] = None,
) -> str:
    """Return the detail text for the clicked category button.

    Args:
        category_name: The clicked button's label. update_categories() sets
            this to the category name, or ``'???'`` for unseen categories.
        categories: Categories to search, each exposing ``name``,
            ``description`` and ``traits``. Defaults to ``C.categories``.

    Returns:
        The matching category's name, description and traits separated by
        blank lines. If no category matches, returns the
        ``'Select a category!'`` prompt rather than falling through to
        ``None``.
    """
    if categories is None:
        categories = C.categories
    for category in categories:
        if category.name == category_name:
            return '\n\n'.join(
                (category.name, category.description, category.traits)
            )
    # No match (e.g. the '???' placeholder label): show the default prompt.
    return 'Select a category!'
def _wrap_label(*paragraphs: str, width: int = 30) -> str:
    """Word-wrap each paragraph to ``width`` columns and join with HTML breaks.

    Lines inside a paragraph are joined with ``<br>`` and paragraphs are
    separated by ``<br><br>`` for Plotly hover text. Paragraphs are wrapped
    one at a time because textwrap (and ``Series.str.wrap``, which uses the
    same defaults) replaces every newline with a space, which would erase the
    blank line between paragraphs.
    """
    return '<br><br>'.join(
        '<br>'.join(textwrap.wrap(paragraph, width)) for paragraph in paragraphs
    )


def _collect_state_graph():
    """Read the coordinator's states and questions into vertex/edge lists.

    Returns:
        ``(nodes, descriptions, edges, edge_questions, edge_options)``:
        vertex names of the form ``'<state> (Stage <level>)'`` with their
        descriptions, ``(source_name, target_name)`` edge pairs, and each
        edge's question text and comma-joined option label(s).
    """
    nodes = []
    descriptions = []
    for level in range(1, C.num_questions_ + 1):
        for name, description in C.states[level].items():
            nodes.append(f'{name} (Stage {level})')
            descriptions.append(description)

    edges = []
    edge_questions = []
    edge_options = []
    for level in range(1, C.num_questions_):
        for state, question_uuid in C.state_question_map[level].items():
            question = C.questions[question_uuid]
            by_target = sorted(question.options.items(), key=lambda item: item[1])
            # Options that lead to the same next state share a single edge
            # whose label lists all of them.
            for foll_state, group in itertools.groupby(
                by_target, key=lambda item: item[1]
            ):
                edges.append((
                    f'{state} (Stage {level})',
                    f'{foll_state} (Stage {level + 1})',
                ))
                edge_questions.append(question.question)
                edge_options.append(', '.join(option for option, _ in group))
    return nodes, descriptions, edges, edge_questions, edge_options


def _tree_coordinates(graph):
    """Lay ``graph`` out with igraph's 'rt' layout and return plot coordinates.

    Returns:
        ``(node_xs, node_ys, edge_xs, edge_ys, mid_xs, mid_ys)``:
        vertex positions, edge line segments separated by ``None`` (Plotly's
        line-gap marker), and each edge's midpoint (anchor for its hover
        label). All lists are empty when the graph has no vertices.
    """
    n_vertices = graph.vcount()
    if n_vertices == 0:
        return [], [], [], [], [], []
    layout = graph.layout('rt')
    coords = [layout[k] for k in range(n_vertices)]
    # Mirror y as 2*max(y) - y, following https://plotly.com/python/tree-plots/
    top = max(y for _, y in coords)
    node_xs = [x for x, _ in coords]
    node_ys = [2 * top - y for _, y in coords]

    edge_xs, edge_ys, mid_xs, mid_ys = [], [], [], []
    for edge in graph.es:
        source, target = edge.tuple
        edge_xs += [node_xs[source], node_xs[target], None]
        edge_ys += [node_ys[source], node_ys[target], None]
        mid_xs.append((node_xs[source] + node_xs[target]) / 2)
        mid_ys.append((node_ys[source] + node_ys[target]) / 2)
    return node_xs, node_ys, edge_xs, edge_ys, mid_xs, mid_ys


def update_graph() -> gr.Plot:
    """Build a Plotly tree of the coordinator's states and linking questions.

    Vertices are the states of stages 1..``C.num_questions_``. Each state that
    has a question gets one directed edge per distinct next-stage state its
    options lead to, labelled with the question and the option(s) chosen.
    Hovering a vertex shows its name and description; hovering an edge's
    midpoint shows its question and options.

    Returns:
        A visible ``gr.Plot`` holding the figure (empty if there are no
        states).
    """
    nodes, descriptions, edges, edge_questions, edge_options = _collect_state_graph()

    graph = Graph(directed=True)
    graph.add_vertices(nodes, attributes={'description': descriptions})
    graph.add_edges(
        edges,
        attributes={'question': edge_questions, 'option': edge_options},
    )
    node_xs, node_ys, edge_xs, edge_ys, mid_xs, mid_ys = _tree_coordinates(graph)

    node_labels = [
        _wrap_label(node.replace('_', ' '), description)
        for node, description in zip(nodes, descriptions)
    ]
    edge_labels = [
        _wrap_label(question.replace('_', ' '), f"[ {option.replace('_', ' ')} ]")
        for question, option in zip(edge_questions, edge_options)
    ]

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=edge_xs, y=edge_ys,
        mode='lines',
        line=dict(color='rgb(210,210,210)', width=1),
    ))
    fig.add_trace(go.Scatter(
        x=node_xs, y=node_ys,
        mode='markers',
        marker=dict(
            symbol='circle-dot', size=18, color='#6175c1',
            line=dict(color='rgb(50,50,50)', width=1),
        ),
        text=node_labels,
        hoverinfo='text',
        opacity=0.8,
    ))
    # Zero-size, fully transparent markers at edge midpoints carry the
    # edge hover labels.
    fig.add_trace(go.Scatter(
        x=mid_xs, y=mid_ys,
        mode='markers',
        marker=dict(
            symbol='circle-dot', size=0, color="#42c744",
            line=dict(color='rgb(50,50,50)', width=0),
        ),
        text=edge_labels,
        hoverinfo='text',
        opacity=0,
    ))
    axis = dict(
        showline=False,
        zeroline=False,
        showgrid=False,
        showticklabels=False,
    )
    fig.update_layout(showlegend=False, xaxis=axis, yaxis=axis)
    return gr.Plot(fig, visible=True)
def control_screen_widgets() -> List[Union[gr.Row, gr.Text, gr.Plot, gr.Button]]:
    """Return every component that ``control_screen`` produces updates for.

    Order: categories row, category text, graph row, graph plot, then the
    category buttons. Presumably used as the ``outputs`` list of the event
    whose handler returns ``control_screen(...)`` (confirm against caller).
    """
    return [categories_row, category_text, graph_row, graph_plot, *category_buttons]
def control_screen(
    is_visible: bool,
) -> Dict[
    Union[gr.Plot, gr.Row, gr.Button, gr.Text],
    Union[gr.update, gr.Row, gr.Plot],
]:
    """Show or hide the progress screen.

    Args:
        is_visible: Whether the screen should be shown.

    Returns:
        Updates keyed by component:
        - both rows are toggled to ``is_visible``;
        - the graph is rebuilt via ``update_graph()`` when shown, or hidden;
        - the category buttons are refreshed via ``update_categories()``;
        - ``category_text`` follows ``is_visible`` (and is reset to its
          prompt when shown).
    """
    updates = {
        categories_row: gr.Row(visible=is_visible),
        graph_row: gr.Row(visible=is_visible),
        graph_plot: update_graph() if is_visible else gr.update(visible=False),
    }
    # update_categories() also returns a category_text update that forces it
    # visible. Merge it first, then override the text's visibility so hiding
    # the screen actually hides the text.
    updates |= update_categories()
    if not is_visible:
        updates[category_text] = gr.update(visible=False)
    return updates
def render():
    """Attach the progress-screen widgets to the layout being built.

    Intended to be called while a ``gr.Blocks`` context is open. The category
    buttons go inside ``categories_row`` and each is wired so that clicking
    it writes ``on_category_click``'s result into ``category_text``. The text
    box follows the row, and the plot goes inside ``graph_row``.
    """
    categories_row.render()
    with categories_row:
        for btn in category_buttons:
            btn.render()
            # The button's own label is passed to the handler as input.
            btn.click(fn=on_category_click, inputs=[btn], outputs=[category_text])
    category_text.render()

    graph_row.render()
    with graph_row:
        graph_plot.render()