Spaces:
Sleeping
Sleeping
| from st_aggrid import GridOptionsBuilder, AgGrid | |
| from streamlit_searchbox import st_searchbox | |
| import streamlit as st | |
| from .load_data import load_dataframe, sort_by, show_dataframe_top, search_by_name, validate_categories | |
| from .plot import plot_radar_chart_name, plot_radar_chart_rows | |
# Benchmark columns whose values are multiplied by 100 before being displayed.
_SCORE_COLUMNS = ["ARC", "GSM8K", "TruthfulQA", "Winogrande", "HellaSwag", "MMLU", "Average"]
# Order in which the scores are listed on a model card.
_CARD_METRICS = ("Average", "ARC", "GSM8K", "TruthfulQA", "Winogrande", "HellaSwag", "MMLU")
# Maximum number of models that are plotted / described at once.
_MAX_MODELS = 3


def _find_model_names(dataframe, model_name):
    """Return the list of model names in *dataframe* that contain *model_name*.

    An empty or ``None`` search term returns every model name. The term is
    matched literally (not as a regular expression) so that user input such
    as ``c++`` or ``(`` cannot raise a regex error.
    """
    names = dataframe["model_name"]
    if not model_name:
        return names.tolist()
    matches = names.str.contains(model_name, regex=False, na=False)
    return names[matches].tolist()


def _lookup_model_record(dataframe, model_name):
    """Return the first row named *model_name* as a dict with scores times 100.

    Returns ``None`` when the model is not present in *dataframe*, which
    happens when a search box still remembers a model that the current dtype
    filter excludes.
    """
    # Work on an explicit copy so the caller's frame is never written to.
    row = dataframe.loc[dataframe["model_name"] == model_name].head(1).copy()
    if row.empty:
        return None
    row[_SCORE_COLUMNS] = row[_SCORE_COLUMNS].astype(float) * 100
    return row.to_dict("records")[0]


def _plot_rows(rows, ordering_metrics, valid_categories):
    """Build the radar chart for *rows*, using the custom metric order if valid."""
    if valid_categories:
        return plot_radar_chart_rows(rows=rows, categories=ordering_metrics)
    return plot_radar_chart_rows(rows=rows)


def _render_model_card(model):
    """Write the name, scores and dtype of one model record to the page."""
    st.markdown("**Model name:** [%s](https://huggingface.co/%s)" % (model["model_name"], model["model_name"]))
    st.markdown("**Results:**")
    st.markdown("""
    * Average: %s
    * ARC: %s
    * GSM8K: %s
    * TruthfulQA: %s
    * Winogrande: %s
    * HellaSwag: %s
    * MMLU: %s
    """ % tuple(round(model[metric], 2) for metric in _CARD_METRICS))
    st.markdown("**dtype:** %s" % model["model_dtype"])


def display_app():
    """Render the Open LLM Leaderboard visualisation page.

    The page shows an explanatory header, a sortable/filterable leaderboard
    grid, three search boxes, a radar chart of the selected models (or of the
    first model of the leaderboard when nothing is selected) and a textual
    summary of each selected model. Models picked through the search boxes
    come before the ones ticked in the grid, and at most three models are kept.
    """
    st.markdown("# Open LLM Leaderboard Viz")
    st.markdown("## Some explanations")
    st.markdown("This is a visualization of the results in [open-llm-leaderboard/results](https://huggingface.co/datasets/open-llm-leaderboard/results)")
    st.markdown("To select a model, click on the checkbox beside its name, or search it by its name in the search boxes **Model 1, Model 2, or Model 3** bellow.")
    st.markdown("You can select up to three models using the search boxes and/or the checkboxes.")
    st.markdown("""In the case you use both the search boxes and the checkboxes, the search boxes will take precedence over the checkboxes,
                i.e. the models searched using the search boxes will be prioritized over the ones selected using the checkboxes.
                Please, search models using the search boxes first, and then use the checkboxes.
                """)
    st.markdown("This app displays the top 100 models by default, but you can change that using the number input in the sidebar.")
    st.markdown("By default as well, the maximum number of row you can display is 500, it is due to the problem with st_aggrid component loading.")
    st.markdown("If your model doesn't show up, please search it by its name.")

    dataframe = load_dataframe()

    st.markdown("## Leaderboard")
    sort_selection = st.selectbox(
        label="Sort by:",
        options=list(dataframe.columns.difference(["model_dtype"])),
        index=1,
    )
    d_type_options = ["all", "torch.bfloat16", "torch.float16", "4bit", "8bit"]
    d_type = st.radio(label="Filter by dtype", options=d_type_options, index=0, horizontal=True)
    number_of_row = st.sidebar.number_input(
        "Number of top rows to display", min_value=100, max_value=500, value="min", step=100
    )

    # Names are sorted alphabetically; every other column is sorted best-first.
    if sort_selection is None:
        sort_selection = "model_name"
    ascending = sort_selection == "model_name"

    # Sidebar configurations
    selection_mode = st.sidebar.radio(label="Selection mode for the rows", options=["single", "multiple"], index=1)
    st.sidebar.write("In multiple mode, you can select up to three models. If you select more than three models, only the first three will be displayed and plotted.")
    ordering_metrics = st.sidebar.text_input(
        label="Order of the metrics on the circle, counter-clock wise, beginning at 3 o'clock.",
        placeholder="ARC, GSM8K, TruthfulQA, Winogrande, HellaSwag, MMLU",
    )
    ordering_metrics = ordering_metrics.replace(" ", "").split(",")
    st.sidebar.markdown("""
    As a reminder, here are the different metrics:
    * ARC
    * GSM8K
    * TruthfulQA
    * Winogrande
    * HellaSwag
    * MMLU
    """)
    st.sidebar.markdown("""
    If there are **typos** in the name of the metrics, or the number of metrics
    is **different of six**, there will be no effect on the chart and the
    default ordering will be used.
    """)
    valid_categories = validate_categories(ordering_metrics)

    dataframe = sort_by(dataframe=dataframe, column_name=sort_selection, ascending=ascending)
    if d_type != "all":
        dataframe = dataframe[dataframe["model_dtype"] == d_type]

    # Dynamic search boxes: search within the sorted and dtype-filtered data.
    def search_model(model_name: str):
        return _find_model_names(dataframe, model_name)

    # Grid data: the top rows, with scores shown as percentages with 2 decimals.
    dataframe_display = show_dataframe_top(number_of_row, dataframe.copy())
    dataframe_display[_SCORE_COLUMNS] = (dataframe[_SCORE_COLUMNS].astype(float) * 100).round(2)

    # Infer basic colDefs from dataframe types
    gb = GridOptionsBuilder.from_dataframe(dataframe_display)
    gb.configure_selection(selection_mode=selection_mode, use_checkbox=True)
    gb.configure_grid_options(domLayout='normal')
    gridOptions = gb.build()

    model_list = []
    column1, _spacer, column2 = st.columns([0.26, 0.05, 0.69], gap="small")
    with column1:
        grid_response = AgGrid(
            dataframe_display,
            gridOptions=gridOptions,
            height=300,
            width='40%'
        )
        searched_names = [
            st_searchbox(label="Model 1", search_function=search_model, key="model_1", default=None),
            st_searchbox(label="Model 2", search_function=search_model, key="model_2", default=None),
            st_searchbox(label="Model 3", search_function=search_model, key="model_3", default=None),
        ]
        for searched_name in searched_names:
            if searched_name is None:
                continue
            record = _lookup_model_record(dataframe, searched_name)
            if record is None:
                # The search box kept a model that the current filter excludes.
                st.warning("Model %s is not available with the current filter." % searched_name)
                continue
            model_list.append(record)

    # Fallback model, plotted when nothing has been selected.
    subdata = dataframe.head(1)
    model_name = subdata["model_name"].values[0] if len(subdata) > 0 else ""

    # grid_response['selected_rows'] is now a Pandas dataframe, we need to
    # convert it to dicts in order to merge with the searchboxes' results.
    selected_rows = grid_response['selected_rows']
    if selected_rows is not None and len(selected_rows) > 0:
        model_list += selected_rows.to_dict("records")
    # Searched models come first, so they survive the truncation.
    model_list = sorted(model_list[:_MAX_MODELS], key=lambda row: row["Average"], reverse=True)

    with column2:
        if model_list:
            figure = _plot_rows(model_list, ordering_metrics, valid_categories)
            st.plotly_chart(figure, use_container_width=False)
        elif len(subdata) > 0:
            if valid_categories:
                figure = plot_radar_chart_name(dataframe=subdata, categories=ordering_metrics, model_name=model_name)
            else:
                figure = plot_radar_chart_name(dataframe=subdata, model_name=model_name)
            st.plotly_chart(figure, use_container_width=True)

    if len(model_list) > 1:
        st.markdown("## Models")
        for column, model in zip(st.columns(len(model_list)), model_list):
            with column:
                _render_model_card(model)
    elif len(model_list) == 1:
        _render_model_card(model_list[0])
        st.markdown("For more details, hover over the radar chart.")
    else:
        st.markdown("**Model name:** %s" % model_name)
        st.markdown("For more details, select the first model in the list/leaderboard.")