Spaces:
Sleeping
Sleeping
adding feature of changing the order of the metrics on the circle of the chart
Browse files- src/display.py +41 -8
- src/load_data.py +31 -1
src/display.py
CHANGED
|
@@ -1,10 +1,7 @@
|
|
| 1 |
-
|
| 2 |
-
#from src.load_data import load_dataframe, sort_by
|
| 3 |
-
#from src.plot import plot_radar_chart_index, plot_radar_chart_name
|
| 4 |
-
#from st_aggrid import GridOptionsBuilder, AgGrid
|
| 5 |
from st_aggrid import GridOptionsBuilder, AgGrid
|
| 6 |
import streamlit as st
|
| 7 |
-
from .load_data import load_dataframe, sort_by, show_dataframe_top, search_by_name
|
| 8 |
from .plot import plot_radar_chart_name, plot_radar_chart_rows
|
| 9 |
|
| 10 |
|
|
@@ -32,8 +29,34 @@ def display_app():
|
|
| 32 |
|
| 33 |
|
| 34 |
name = st.text_input(label = ":mag: Search by name")
|
|
|
|
|
|
|
| 35 |
selection_mode = st.sidebar.radio(label= "Selection mode for the rows", options = ["single", "multiple"], index=0)
|
| 36 |
st.sidebar.write("In multiple mode, you can select up to three models. If you select more than three models, only the first three will be displayed and plotted.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
len_name_input = len(name)
|
| 38 |
if len_name_input > 0:
|
| 39 |
dataframe_by_search = search_by_name(name)
|
|
@@ -79,12 +102,22 @@ def display_app():
|
|
| 79 |
|
| 80 |
with column2:
|
| 81 |
if grid_response['selected_rows'] is not None and len(grid_response['selected_rows']) > 0:
|
| 82 |
-
figure =
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
st.plotly_chart(figure, use_container_width=False)
|
|
|
|
| 85 |
else:
|
| 86 |
if len(subdata)>0:
|
| 87 |
-
figure =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
st.plotly_chart(figure, use_container_width=True)
|
| 89 |
|
| 90 |
if grid_response['selected_rows'] is not None and len(grid_response['selected_rows']) > 1:
|
|
|
|
| 1 |
+
|
|
|
|
|
|
|
|
|
|
| 2 |
from st_aggrid import GridOptionsBuilder, AgGrid
|
| 3 |
import streamlit as st
|
| 4 |
+
from .load_data import load_dataframe, sort_by, show_dataframe_top, search_by_name, validate_categories
|
| 5 |
from .plot import plot_radar_chart_name, plot_radar_chart_rows
|
| 6 |
|
| 7 |
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
name = st.text_input(label = ":mag: Search by name")
|
| 32 |
+
|
| 33 |
+
#Sidebar configurations
|
| 34 |
selection_mode = st.sidebar.radio(label= "Selection mode for the rows", options = ["single", "multiple"], index=0)
|
| 35 |
st.sidebar.write("In multiple mode, you can select up to three models. If you select more than three models, only the first three will be displayed and plotted.")
|
| 36 |
+
ordering_metrics = st.sidebar.text_input(label = "Order of the metrics on the circle, counter-clock wise, beginning at 3 o'clock.",
|
| 37 |
+
placeholder = "ARC, GSM8K, TruthfulQA, Winogrande, HellaSwag, MMLU")
|
| 38 |
+
|
| 39 |
+
ordering_metrics = ordering_metrics.replace(" ", "")
|
| 40 |
+
ordering_metrics = ordering_metrics.split(",")
|
| 41 |
+
|
| 42 |
+
st.sidebar.markdown("""
|
| 43 |
+
As a reminder, here are the different metrics:
|
| 44 |
+
* ARC
|
| 45 |
+
* GSM8K
|
| 46 |
+
* TruthfulQA
|
| 47 |
+
* Winogrande
|
| 48 |
+
* HellaSwag
|
| 49 |
+
* MMLU
|
| 50 |
+
""")
|
| 51 |
+
st.sidebar.markdown("""
|
| 52 |
+
If there are **typos** in the name of the metrics, or the number of metrics
|
| 53 |
+
is **different of six**, there will be no effect on the chart and the
|
| 54 |
+
default ordering will be used.
|
| 55 |
+
""")
|
| 56 |
+
|
| 57 |
+
valid_categories = validate_categories(ordering_metrics)
|
| 58 |
+
|
| 59 |
+
# Search bar
|
| 60 |
len_name_input = len(name)
|
| 61 |
if len_name_input > 0:
|
| 62 |
dataframe_by_search = search_by_name(name)
|
|
|
|
| 102 |
|
| 103 |
with column2:
|
| 104 |
if grid_response['selected_rows'] is not None and len(grid_response['selected_rows']) > 0:
|
| 105 |
+
figure = None
|
| 106 |
+
if valid_categories:
|
| 107 |
+
|
| 108 |
+
figure = plot_radar_chart_rows(rows=grid_response['selected_rows'][:3], categories = ordering_metrics)
|
| 109 |
+
else:
|
| 110 |
+
figure = plot_radar_chart_rows(rows=grid_response['selected_rows'][:3])
|
| 111 |
st.plotly_chart(figure, use_container_width=False)
|
| 112 |
+
|
| 113 |
else:
|
| 114 |
if len(subdata)>0:
|
| 115 |
+
figure = None
|
| 116 |
+
if valid_categories:
|
| 117 |
+
figure = plot_radar_chart_name(dataframe=subdata, categories = ordering_metrics, model_name=model_name)
|
| 118 |
+
else:
|
| 119 |
+
figure = plot_radar_chart_name(dataframe=subdata, model_name=model_name)
|
| 120 |
+
|
| 121 |
st.plotly_chart(figure, use_container_width=True)
|
| 122 |
|
| 123 |
if grid_response['selected_rows'] is not None and len(grid_response['selected_rows']) > 1:
|
src/load_data.py
CHANGED
|
@@ -54,4 +54,34 @@ def search_by_name(name: str) -> pd.DataFrame:
|
|
| 54 |
"""
|
| 55 |
dataframe = load_dataframe()
|
| 56 |
indexes = dataframe["model_name"].str.contains(name)
|
| 57 |
-
return dataframe[indexes]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
"""
|
| 55 |
dataframe = load_dataframe()
|
| 56 |
indexes = dataframe["model_name"].str.contains(name)
|
| 57 |
+
return dataframe[indexes]
|
| 58 |
+
|
| 59 |
+
def validate_categories(categories: list) -> bool:
|
| 60 |
+
"""
|
| 61 |
+
validate a list of categories to the columns in the dataframe
|
| 62 |
+
Arguments:
|
| 63 |
+
- categories: a list of categories for the ordering of the columns in the dataframe
|
| 64 |
+
|
| 65 |
+
This expects a list with six elements that should be (not necessary in order):
|
| 66 |
+
- ARC
|
| 67 |
+
- GSM8K
|
| 68 |
+
- TruthfulQA
|
| 69 |
+
- Winogrande
|
| 70 |
+
- HellaSwag
|
| 71 |
+
- MMLU
|
| 72 |
+
|
| 73 |
+
Returns
|
| 74 |
+
- True if the list has the right number of element and right elements
|
| 75 |
+
- False otherwise
|
| 76 |
+
"""
|
| 77 |
+
valid_categories = False
|
| 78 |
+
if len(categories) == 6:
|
| 79 |
+
if ("ARC" in categories and "GSM8K" in categories and "TruthfulQA" in categories
|
| 80 |
+
and "Winogrande" in categories and "HellaSwag" in categories and "MMLU" in categories):
|
| 81 |
+
valid_categories = True
|
| 82 |
+
else:
|
| 83 |
+
valid_categories = False
|
| 84 |
+
else:
|
| 85 |
+
valid_categories = False
|
| 86 |
+
|
| 87 |
+
return valid_categories
|