Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| import pandas as pd | |
| import plotly | |
| import plotly.graph_objects as go | |
| from assets.color import color_dict | |
| from assets.content import KEYPOINT_DISTRIBUTION, DIFFICULTY_DISTRIBUTION | |
| from assets.path import SEASON | |
def read_testset(season):
    """Load the test dataset JSON for one season.

    Args:
        season: Key looked up in ``SEASON``; the mapped value is used as the
            sub-directory of ``results`` that contains ``test_dataset.json``.

    Returns:
        The contents of ``results/<SEASON[season]>/test_dataset.json`` parsed
        by ``pd.read_json``.
    """
    season_dir = SEASON[season]
    json_path = os.path.join("results", season_dir, "test_dataset.json")
    return pd.read_json(json_path)
def build_keypoint_plot(dataset):
    """Build a sunburst chart of keypoint (category) frequencies.

    Each cell of ``dataset['categories']`` is iterated as a collection of
    categories, and each category as a root-to-leaf sequence of keypoints.
    Every prefix of such a sequence is one sunburst node; a node's value is
    increased by the number of rows sharing that exact ``categories`` cell
    each time the prefix occurs. Empty/falsy keypoints are displayed as
    "未分类".

    Nodes are identified by their full path, not by their label text, so a
    keypoint name that appears under several parents (for example the
    "未分类" placeholder) becomes separate nodes instead of being merged
    under whichever parent happened to be seen first. Because each prefix
    node is incremented whenever any longer path through it is, a parent's
    value is always >= the sum of its children, as ``branchvalues="total"``
    requires.

    Args:
        dataset: DataFrame with a ``categories`` column.

    Returns:
        A plotly ``go.Figure`` holding a single ``go.Sunburst`` trace.
    """
    # Full keypoint path (tuple of display labels) -> position in the lists
    # below; that position, as a string, is also the node's plotly id.
    node_index = {}
    ids, labels, parents, values, colors = [], [], [], [], []

    # List cells are unhashable, so value_counts() cannot group them as-is;
    # normalise every cell to a tuple of tuples first. Missing cells are
    # dropped up front, mirroring value_counts' default dropna=True.
    row_counts = (
        dataset['categories']
        .dropna()
        .map(lambda cats: tuple(tuple(category) for category in cats))
        .value_counts()
    )

    for categories, count in row_counts.items():
        for category in categories:
            path = ()
            parent_id = ""  # "" marks a root node in plotly's parents list
            for keypoint in category:
                label = keypoint or "未分类"
                path += (label,)
                if path not in node_index:
                    node_index[path] = len(ids)
                    ids.append(str(len(ids)))
                    labels.append(label)
                    parents.append(parent_id)
                    values.append(0)
                    # Colour is chosen by the raw first keypoint of the path.
                    colors.append(color_dict[category[0]])
                idx = node_index[path]
                values[idx] += count
                parent_id = ids[idx]

    fig = go.Figure(go.Sunburst(
        ids=ids,
        labels=labels,
        parents=parents,
        values=values,
        branchvalues="total",
        insidetextorientation='radial',
        marker={"colors": colors},
    ))
    return fig
def build_difficulty_plot(dataset):
    """Build a bar chart of row counts per difficulty level.

    Args:
        dataset: DataFrame with a ``difficulty`` column.

    Returns:
        A plotly ``go.Figure`` with one bar per distinct difficulty value,
        ordered by that value. Bars are coloured by their count using the
        Viridis colorscale (with a colorbar titled "Total"), and the y axis
        is logarithmic.
    """
    per_level = dataset['difficulty'].value_counts().sort_index()
    levels = per_level.index.tolist()
    totals = per_level.tolist()

    bar = go.Bar(
        x=levels,
        y=totals,
        marker={"color": totals, "colorscale": "Viridis",
                "colorbar": {"title": "Total"}},
    )
    figure = go.Figure([bar])
    # Log-scaled y axis (presumably so small counts stay visible beside
    # large ones).
    figure.update_layout(yaxis=dict(type='log'))
    return figure
def build_plot(season):
    """Build the keypoint and difficulty figures for one season's test set.

    Args:
        season: Key into ``SEASON`` selecting which ``results`` sub-directory
            holds the ``test_dataset.json`` to read.

    Returns:
        A ``(keypoint_figure, difficulty_figure)`` tuple produced by
        ``build_keypoint_plot`` and ``build_difficulty_plot`` on the same
        loaded dataset.
    """
    # Reuse the shared loader instead of duplicating its path construction.
    dataset = read_testset(season)
    return build_keypoint_plot(dataset), build_difficulty_plot(dataset)
def create_data(top_components):
    """Create the "All data" and "Test Data" tabs with their distribution plots.

    The tabs and rows are opened as Gradio layout context managers here, so
    this is presumably meant to be called while a ``gr.Blocks`` layout is
    being built (assumption — confirm against the caller).

    Args:
        top_components: Not used by this function.

    Returns:
        A dict mapping "all_keypoint", "all_difficulty", "test_keypoint" and
        "test_difficulty" to the four ``gr.Plot`` components created.
    """
    # Test-set figures are computed from the "latest" season's data.
    test_keypoint_fig, test_difficulty_fig = build_plot("latest")
    plots = {}

    # "All data" tab: figures are deserialised from pre-built plotly JSON.
    with gr.Tab("All data"), gr.Row():
        plots["all_keypoint"] = gr.Plot(
            plotly.io.from_json(KEYPOINT_DISTRIBUTION),
            label="Keypoint Distribution",
        )
        plots["all_difficulty"] = gr.Plot(
            plotly.io.from_json(DIFFICULTY_DISTRIBUTION),
            label="Difficulty Distribution",
        )

    # "Test Data" tab: figures built above from the test dataset.
    with gr.Tab("Test Data"), gr.Row():
        plots["test_keypoint"] = gr.Plot(
            test_keypoint_fig, label="Keypoint Distribution")
        plots["test_difficulty"] = gr.Plot(
            test_difficulty_fig, label="Difficulty Distribution")

    return plots