Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import plotly.graph_objects as go | |
| import dash | |
| from dash import dcc, html, Input, Output | |
# Create the Dash application instance, naming it after this module.
app = dash.Dash(name=__name__)
# Statistics report file names, one per evaluated model ("<model>_statistics.txt").
file_names = [
    f"{model}_statistics.txt"
    for model in (
        "GPT-4o",
        "GPT-4o-mini",
        "Llama-3.1-8B-ft",
        "Llama-3.1-8B",
        "Llama-3.1-70B",
        "Mixtral-8x7B",
        "Qwen2-72B",
        "Qwen2-7B",
        "Llama-2-7b-hf",
    )
]
# Read the GPT-4o report once at import time to fix the canonical task-key
# order; load_data slices this list so every model's columns refer to the
# same tasks in the same order.
# Decode explicitly as UTF-8 (the JSON interchange encoding) instead of the
# platform's locale default.
with open("./test_results_report/GPT-4o_statistics.txt", "r", encoding="utf-8") as f:
    results = json.load(f)
# Iterating a dict yields its keys in insertion order.
keys = list(results["exact_match"])
def load_data(file_name, main_metric="exact_match", r=None, task_keys=None,
              data_dir="./test_results_report"):
    """Load one model's per-task statistics as parallel column lists.

    Args:
        file_name: JSON statistics file located inside ``data_dir``.
        main_metric: Top-level key of the JSON to read (e.g. ``"exact_match"``).
        r: ``(start, stop)`` slice applied to ``task_keys``. ``None`` (the
            default) selects every key, like the former ``(0, len(keys))``.
        task_keys: Ordered task keys to read. ``None`` uses the module-level
            ``keys`` list.
        data_dir: Directory containing the statistics files.

    Returns:
        A 9-tuple ``(tasks, well_learned_digit, has_performance_digit,
        in_domain, out_domain, short_range, medium_range, long_range,
        very_long_range)``. ``tasks`` holds labels ``"<Task><br /><Domain>"``.
        Every other entry lists that metric's value per task, in the same order.

    Raises:
        KeyError: if ``main_metric``, a task key, or a metric is missing.
        IndexError: if a task key has fewer than three ``_``-separated parts.
    """
    # Resolve defaults at call time so the function does not need the
    # module-level ``keys`` to exist when it is defined.
    if task_keys is None:
        task_keys = keys
    if r is None:
        r = (0, len(task_keys))

    # Metric columns, in the order they appear in the returned tuple.
    metric_names = (
        "well_learned_digit",
        "has_performance_digit",
        "in_domain",
        "out_domain",
        "short_range",
        "medium_range",
        "long_range",
        "very_long_range",
    )

    with open(os.path.join(data_dir, file_name), "r", encoding="utf-8") as f:
        stats = json.load(f)
    stats_exm = stats[main_metric]

    tasks = []
    # Plain dict of lists instead of eval()-ing dynamically built code.
    columns = {name: [] for name in metric_names}
    for key in task_keys[r[0]:r[1]]:
        # Keys end with three "_"-separated fields. The first of the three is
        # the domain shown in the label; the last two are not used here.
        # Everything before them forms the task name.
        parts = key.split("_")
        domain = parts[-3]
        task = " ".join(word.capitalize() for word in parts[:-3])
        tasks.append(f"{task}<br />{domain}")
        entry = stats_exm[key]
        for name in metric_names:
            columns[name].append(entry[name])
    return (tasks, *(columns[name] for name in metric_names))
# Task names offered in each number domain's selector.
# The first six tasks are shared by the integer, float and scientific domains.
_COMMON_TASKS = ["Add", "Sub", "Max", "Max Hard", "Multiply Hard", "Multiply Easy"]

intTasks = _COMMON_TASKS + [
    "Digit Max", "Digit Add", "Get Digit", "Length", "Truediv", "Floordiv",
    "Mod", "Mod Easy", "Count", "Sig", "To Scient",
]
floatTasks = _COMMON_TASKS + ["Digit Max", "Digit Add", "Get Digit", "Length", "To Scient"]
fractionTasks = ["Add", "Add Easy", "Sub", "Max", "Multiply Hard", "Multiply Easy", "Truediv", "To Float"]
sciTasks = _COMMON_TASKS + ["To Float"]
# Load GPT-4o's statistics once at import time and expose each returned
# column as a module-level name.
(
    tasks,
    well_learned_digit,
    has_performance_digit,
    in_domain,
    out_domain,
    short_range,
    medium_range,
    long_range,
    very_long_range,
) = load_data(file_name="GPT-4o_statistics.txt")
# De-duplicate and sort the task labels ("<Task><br /><Domain>").
# sorted() accepts any iterable, so the intermediate list() was redundant.
unique_tasks = sorted(set(tasks))
def plot(main_metric, selected_files, selected_metrics, selected_tasks, r):
    """Build a grouped bar chart comparing models on the selected tasks.

    Args:
        main_metric: Top-level key of each statistics file to read
            (e.g. ``"exact_match"``). Title-cased, it also becomes the chart title.
        selected_files: Statistics file names. Each file becomes one bar trace,
            named after the file minus its ``"_statistics.txt"`` suffix.
        selected_metrics: Column names returned by ``load_data`` (e.g.
            ``"short_range"``). Each selected task gets one bar per metric.
        selected_tasks: Task labels (``"<Task><br /><Domain>"``) to include.
        r: ``(start, stop)`` slice into the task keys, forwarded to ``load_data``.

    Returns:
        A ``plotly.graph_objects.Figure`` with a two-level (task, metric) x-axis.

    Raises:
        KeyError: if a name in ``selected_metrics`` is not a ``load_data`` column.
    """
    # Names of the metric columns returned by load_data after ``tasks``, in
    # tuple order. Metrics are looked up by name instead of eval()-ing strings.
    metric_names = (
        "well_learned_digit",
        "has_performance_digit",
        "in_domain",
        "out_domain",
        "short_range",
        "medium_range",
        "long_range",
        "very_long_range",
    )
    # Short x-axis sub-labels for the range metrics. Any other metric is
    # labelled with its own name, so labels always match the bars they sit under.
    sub_labels = {
        "short_range": "S",
        "medium_range": "M",
        "long_range": "L",
        "very_long_range": "XL",
    }
    colors = ["#2C6344", "#5F9C61", "#A4C97C", "#61496D", "#B092B6", "#CAC1D4",
              "#308192", "#E38D26", "#F1CC74", "#C74D26", "#5EA7B8", "#AED2E2"]
    colors.reverse()
    wanted = set(selected_tasks)

    fig = go.Figure()
    for idx, file_name in enumerate(selected_files):
        tasks, *metric_columns = load_data(file_name, main_metric=main_metric, r=r)
        by_name = dict(zip(metric_names, metric_columns))
        chosen = [(metric, by_name[metric]) for metric in selected_metrics]

        task_axis = []
        metric_axis = []
        performance = []
        for i, task in enumerate(tasks):
            if task not in wanted:
                continue
            for metric, column in chosen:
                task_axis.append(task)
                metric_axis.append(sub_labels.get(metric, metric))
                performance.append(column[i])

        fig.add_trace(go.Bar(
            x=[task_axis, metric_axis],
            y=performance,
            # Strip the 15-character "_statistics.txt" suffix for the legend.
            name=file_name[:-15],
            marker_color=colors[idx % len(colors)],
        ))
    fig.update_layout(
        barmode='group',
        xaxis_tickangle=-45,
        template="ggplot2",
        autosize=False,
        width=1500,
        height=400,
        xaxis=dict(showgrid=False),
        title=" ".join(list(map(str.capitalize, main_metric.split("_")))),
        margin=dict(l=20, r=10, t=80, b=20),
    )
    return fig
# Page layout: title, metric picker, model picker, per-domain task pickers and
# the performance chart, centred in a column at most 1200px wide.
_PANEL_STYLE = {"padding": "10px", "border": "1px solid #ccc", "borderRadius": "5px", "marginBottom": "20px"}
_LABEL_STYLE = {"fontWeight": "bold", "marginRight": "10px"}


def _task_group(title, component_id, task_names, domain):
    """Return ``[heading, checklist]`` for one number domain's task picker.

    Each checklist value is ``task + '<br />' + domain``. Only the ``"Add"``
    task starts selected.
    """
    return [
        html.H4(title, style={"fontWeight": "bold", "marginTop": "20px"}),
        dcc.Checklist(
            id=component_id,
            options=[{'label': name, 'value': name + '<br />' + domain} for name in task_names],
            value=['Add' + '<br />' + domain],
            inline=True,
            style={"marginBottom": "10px"},
        ),
    ]


# (heading, component id, task names, domain suffix) for each task checklist.
_TASK_GROUPS = (
    ("Integer Tasks", 'int-task-selector', intTasks, 'Integer'),
    ("Float Tasks", 'float-task-selector', floatTasks, 'Float'),
    ("Fraction Tasks", 'fraction-task-selector', fractionTasks, 'Fraction'),
    ("Scientific Tasks", 'sci-task-selector', sciTasks, 'ScientificNotation'),
)

# Radio buttons choosing which top-level statistics key to plot.
_metric_panel = html.Div(
    children=[
        html.Label("Select Metric:", style=_LABEL_STYLE),
        dcc.RadioItems(
            id='metric-selector',
            options=[
                {'label': 'Exact Match', 'value': 'exact_match'},
                {'label': 'Digit Match', 'value': 'digit_match'},
                {'label': 'Dlength', 'value': 'dlength'},
            ],
            value='exact_match',  # initially selected metric
            inline=True,
            style={"marginBottom": "20px"},
        ),
    ],
    style=_PANEL_STYLE,
)

# Checkboxes choosing which models' statistics files to plot.
_model_panel = html.Div(
    children=[
        html.Label("Select Models:", style=_LABEL_STYLE),
        dcc.Checklist(
            id='file-selector',
            options=[{'label': name[:-15], 'value': name} for name in file_names],
            value=['GPT-4o_statistics.txt', 'Llama-3.1-8B-ft_statistics.txt', 'Mixtral-8x7B_statistics.txt', 'Qwen2-72B_statistics.txt'],
            inline=True,
            style={"marginBottom": "20px"},
        ),
    ],
    style=_PANEL_STYLE,
)

# One heading plus checklist per number domain, in a single panel.
_task_panel = html.Div(
    children=[component for group in _TASK_GROUPS for component in _task_group(*group)],
    style=_PANEL_STYLE,
)

app.layout = html.Div(
    children=[
        html.H1("NUPA Performance", style={"textAlign": "center", "marginBottom": "20px"}),
        _metric_panel,
        _model_panel,
        _task_panel,
        dcc.Graph(id='performance-plot'),
    ],
    style={"maxWidth": "1200px", "margin": "0 auto"},
)
# Register update_figure as the Dash callback that redraws the chart whenever
# any selector changes. Without this decorator the function is never invoked
# and the graph stays empty.
@app.callback(
    Output('performance-plot', 'figure'),
    Input('metric-selector', 'value'),
    Input('file-selector', 'value'),
    Input('int-task-selector', 'value'),
    Input('float-task-selector', 'value'),
    Input('fraction-task-selector', 'value'),
    Input('sci-task-selector', 'value'),
)
def update_figure(main_metric, selected_files, selected_int_tasks, selected_float_tasks, selected_fraction_tasks, selected_sci_tasks):
    """Rebuild the performance figure from the current selector values.

    Args:
        main_metric: Selected statistics key (e.g. ``"exact_match"``).
        selected_files: Selected statistics file names.
        selected_int_tasks, selected_float_tasks, selected_fraction_tasks,
        selected_sci_tasks: Selected task labels for each number domain.

    Returns:
        The figure produced by ``plot`` for the four range metrics.
    """
    selected_metrics = ["short_range", "medium_range", "long_range", "very_long_range"]
    # Treat a missing (None) checklist value as "nothing selected".
    selected_tasks = [
        *(selected_int_tasks or []),
        *(selected_float_tasks or []),
        *(selected_fraction_tasks or []),
        *(selected_sci_tasks or []),
    ]
    # NOTE(review): hard-coded to the first 42 task keys of the report. Confirm
    # this range covers every task offered in the selectors.
    r = (0, 42)
    return plot(main_metric, selected_files or [], selected_metrics, selected_tasks, r)
# Script entry point: serve the dashboard on all network interfaces, port 7860.
if __name__ == '__main__':
    # Dash passes ``debug`` through to Flask. Flask's debug mode enables the
    # Werkzeug interactive debugger, which can run arbitrary Python from the
    # browser, so it must not be on by default for a server bound to 0.0.0.0.
    # Opt in for local development with DASH_DEBUG=1 (or "true"/"yes").
    debug_enabled = os.environ.get("DASH_DEBUG", "").strip().lower() in {"1", "true", "yes"}
    app.run(debug=debug_enabled, host='0.0.0.0', port=7860)