Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import plotly.graph_objects as go | |
| import numpy as np | |
| import pandas as pd | |
# Hugging Face brand palette: the primary fill/line colours first, followed by
# extra fill/line pairs used when several models are overlaid on one chart.
fillcolor, line_color = "#FFD21E", "#FF9D00"
fill_color_list = [fillcolor, "#F05998", "#40BAF0"]
line_color_list = [line_color, "#5E233C", "#194A5E"]

# Trace opacity used by the single-model radar charts.
opacity = 0.75

# Metrics drawn as the radar chart's axes, in plotting order.
categories = ["ARC", "GSM8K", "TruthfulQA", "Winogrande", "HellaSwag", "MMLU"]

# Expected column order of the dataset (default `columns` for row-based plotting).
columns = [
    "index",
    "model_name",
    "model_dtype",
    "ARC",
    "HellaSwag",
    "TruthfulQA",
    "Winogrande",
    "GSM8K",
    "MMLU",
    "Average",
]
# @st.cache_data
def plot_radar_chart_index(
    dataframe: pd.DataFrame,
    index: int,
    categories: list = categories,
    fillcolor: str = fillcolor,
    line_color: str = line_color,
):
    """
    Build a single-model radar chart from one row of ``dataframe``.

    Arguments:
        dataframe: a pandas DataFrame holding a "model_name" column and one
            column per entry of ``categories``
        index: the ``.loc`` label of the row to plot
        categories: the metric columns used as the chart's axes
        fillcolor: colour used to fill the enclosed area
        line_color: colour of the trace outline

    Returns:
        a plotly ``go.Figure`` containing one filled Scatterpolar trace. The
        scores are multiplied by 100, rounded to 2 decimals and drawn on a
        fixed 0-100 radial axis. The legend is hidden.
    """
    raw_scores = dataframe.loc[index, categories].to_numpy() * 100
    scores = np.round(raw_scores.astype(float), decimals=2)

    # Repeat the first point (and axis label) at the end so the polygon closes.
    closed_scores = np.append(scores, scores[0])
    closed_axes = categories.copy()
    closed_axes.append(categories[0])

    trace = go.Scatterpolar(
        r=closed_scores,
        theta=closed_axes,
        fill="toself",
        fillcolor=fillcolor,
        opacity=opacity,
        line=dict(color=line_color),
        name=dataframe.loc[index, "model_name"],
    )

    figure = go.Figure()
    figure.add_trace(trace)
    radial_axis = dict(visible=True, range=[0, 100.0])
    figure.update_layout(polar=dict(radialaxis=radial_axis), showlegend=False)
    return figure
# @st.cache_data
def plot_radar_chart_name(
    dataframe: pd.DataFrame,
    model_name: str,
    categories: list = categories,
    fillcolor: str = fillcolor,
    line_color: str = line_color,
):
    """
    Build a single-model radar chart for the row whose "model_name" equals ``model_name``.

    Arguments:
        dataframe: a pandas DataFrame holding a "model_name" column and one
            column per entry of ``categories``
        model_name: the value to look up in the "model_name" column; if
            several rows match, the first one is plotted
        categories: the metric columns used as the chart's axes
        fillcolor: colour used to fill the enclosed area
        line_color: colour of the trace outline

    Returns:
        a plotly ``go.Figure`` containing one filled Scatterpolar trace. The
        scores are multiplied by 100, rounded to 2 decimals and drawn on a
        fixed 0-100 radial axis. The legend is hidden.

    Raises:
        IndexError: if no row of ``dataframe`` has ``model_name`` in its
            "model_name" column.
    """
    matches = dataframe.loc[dataframe["model_name"] == model_name, categories]
    if len(matches) == 0:
        raise IndexError(f"no row with model_name == {model_name!r} in dataframe")

    # Plot the first match as a 1-D vector. Keeping the 2-D
    # (n_matches, n_categories) array would make np.append flatten it, so `r`
    # would get more points than `theta` has labels.
    data = matches.iloc[0].to_numpy() * 100
    data = data.astype(float).round(decimals=2)

    # Repeat the first point (and axis label) at the end so the polygon closes.
    data = np.append(data, data[0])
    categories_theta = categories.copy()
    categories_theta.append(categories[0])

    fig = go.Figure()
    fig.add_trace(go.Scatterpolar(
        r=data,
        theta=categories_theta,
        fill="toself",
        fillcolor=fillcolor,
        opacity=opacity,
        line=dict(color=line_color),
        name=model_name,
    ))
    fig.update_layout(
        polar=dict(radialaxis=dict(visible=True, range=[0, 100.0])),
        showlegend=False,
    )
    return fig
# @st.cache_data
def plot_radar_chart_rows(
    rows: object,
    columns: list = columns,
    categories: list = categories,
    fillcolor_list: list = fill_color_list,
    line_color_list: list = line_color_list,
):
    """
    Overlay one radar trace per selected row (e.g. the models chosen via a checkbox).

    Arguments:
        rows: an iterable whose elements are dicts with ``columns`` as their keys
        columns: the columns used to build the intermediate DataFrame
        categories: the metric columns used as the chart's axes
        fillcolor_list: fill colours; trace ``i`` uses entry
            ``i % len(fillcolor_list)``, so the palette repeats when there are
            more rows than colours
        line_color_list: outline colours, cycled the same way

    Returns:
        a plotly ``go.Figure`` with one filled Scatterpolar trace per row on a
        fixed 0-100 radial axis. The legend is shown only when more than one
        row is plotted.
    """
    fig = go.Figure()
    dataset = pd.DataFrame(rows, columns=columns)

    # Values are plotted as-is. Unlike plot_radar_chart_index/_name they are
    # not multiplied by 100. Presumably rows already hold percentages; TODO confirm.
    data = dataset[categories].to_numpy().astype(float)
    # Repeat the first metric at the end of every row so each polygon closes.
    data = np.append(data, data[:, :1], axis=1)
    categories_theta = categories.copy()
    categories_theta.append(categories[0])

    for i, (values, name) in enumerate(zip(data, dataset["model_name"])):
        # Successive traces fade (0.75, 0.55, 0.35, ...) so overlapping areas
        # stay visible. The floor keeps opacity inside plotly's valid [0, 1]
        # range when many rows are selected.
        trace_opacity = max(0.75 - 0.2 * i, 0.15)
        fig.add_trace(go.Scatterpolar(
            r=values,
            theta=categories_theta,
            fill="toself",
            # Cycle through the palettes so extra rows don't raise IndexError.
            fillcolor=fillcolor_list[i % len(fillcolor_list)],
            opacity=trace_opacity,
            line=dict(color=line_color_list[i % len(line_color_list)]),
            name=name,
        ))

    fig.update_layout(
        polar=dict(radialaxis=dict(visible=True, range=[0, 100.0])),
        showlegend=len(dataset) > 1,
    )
    return fig