Spaces:
Configuration error
Configuration error
| import evaluate | |
| import json | |
| import sys | |
| from pathlib import Path | |
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| import ast | |
| from ece import ECE # loads local instead | |
| import matplotlib.pyplot as plt | |
# Optional seaborn styling, currently disabled:
#     import seaborn as sns
#     sns.set_style('white')
#     sns.set_context("paper", font_scale=1)  # 2
# plt.rcParams['figure.figsize'] = [10, 7]
plt.rcParams.update({"figure.dpi": 300})  # render figures at 300 dpi

# Switch to the non-interactive Agg backend; GUI backends fail with
# "RuntimeError: main thread is not in main loop" when drawing outside the
# main thread, see
# https://stackoverflow.com/questions/14694408/runtimeerror-main-thread-is-not-in-main-loop
plt.switch_backend("agg")
# Controls passed to compute_and_plot after the data table, in the order of its
# parameters: n_bins, bin_range, scheme, proxy, p.
sliders = [
    # At least one bin is required. step=1 is explicit so the slider only
    # offers whole-number bin counts (an automatic step can be fractional).
    gr.Slider(1, 100, value=10, step=1, label="n_bins"),
    gr.Slider(
        0, 100, value=None, label="bin_range", visible=False
    ),  # DEV: need to have a double slider
    gr.Dropdown(choices=["equal-range", "equal-mass"], value="equal-range", label="scheme"),
    gr.Dropdown(choices=["upper-edge", "center"], value="upper-edge", label="proxy"),
    gr.Dropdown(choices=[1, 2, np.inf], value=1, label="p"),
]
# Initial value of every control, in the same order as `sliders`.
slider_defaults = [slider.value for slider in sliders]
# Example data: per-sample probability vectors and their reference labels.
df = {
    "predictions": [[0.6, 0.2, 0.2], [0, 0.95, 0.05], [0.7, 0.1, 0.2]],
    "references": [0, 1, 2],
}

# Editable input table with one prediction vector and one reference per row.
# gr.Dataframe replaces the deprecated gr.inputs.Dataframe; row_count=3 keeps
# the three-row layout of the legacy component and matches the example rows.
component = gr.Dataframe(
    headers=["predictions", "references"],
    row_count=3,
    col_count=2,
    datatype="number",
    type="pandas",
)
component.value = [
    [[0.6, 0.2, 0.2], 0],
    [[0.7, 0.1, 0.2], 2],
    [[0, 0.95, 0.05], 1],
]

# Interface examples are nested lists of input *values* (one per input
# component: the table followed by every control), not component objects.
sample_data = [[component.value] + slider_defaults]  # json.dumps(df)
# Directory containing this app; used to locate README.md for the interface
# article. Derived from __file__ because sys.path[0] is only the script's
# directory when the file is executed directly as a script.
local_path = Path(__file__).resolve().parent

# Metric implementation from the local `ece` module.
metric = ECE()
# Alternative: load the metric through evaluate.load instead of the local module.
# module = evaluate.load("jordyvl/ece")
# launch_gradio_widget(module)

# Switch inputs and compute_fn
def reliability_plot(results):
    """Draw a reliability diagram and a bin-frequency panel from detailed ECE results.

    Args:
        results: dict produced by the ECE metric with ``detail=True``. The keys
            read here are "y_bar", "p_bar", "bin_freq", "accuracy" and
            "p_bar_cont". Presumably "y_bar" is the per-bin proxy confidence
            (upper bin edge or bin center, depending on the ``proxy`` option)
            and "p_bar" the per-bin empirical accuracy with NaN for empty
            bins — TODO confirm against the metric implementation.

    Returns:
        A ``Figure`` with the reliability diagram on top (2/3 of the height)
        and the bin frequencies plus accuracy / average-confidence markers below.
    """
    y_bar = np.asarray(results["y_bar"], dtype=float)
    p_bar = np.asarray(results["p_bar"], dtype=float)
    n_bins = len(y_bar)

    # Build the figure without pyplot's figure manager: this runs once per
    # request in a long-lived server, and figures created with plt.figure()
    # stay referenced by pyplot until explicitly closed (an unbounded leak).
    fig = plt.Figure()
    grid = fig.add_gridspec(3, 1)
    ax1 = fig.add_subplot(grid[:2, 0])
    ax2 = fig.add_subplot(grid[2, 0])

    # Perfect-calibration diagonal from 0 up to the last bin's proxy value.
    ranged = np.linspace(0.0, y_bar[-1], n_bins)
    ax1.plot(
        ranged,
        ranged,
        color="darkgreen",
        ls="dotted",
        label="Perfect",
    )

    # Width of each bin measured leftwards from its y_bar value (bins are taken
    # to start at 0). With the upper-edge proxy each bar spans exactly its bin;
    # NOTE(review): with the center proxy bars end at the center — confirm.
    widths = np.diff(y_bar, prepend=0.0)

    # Empirical accuracy per bin; NaN entries (empty bins) are not drawn.
    bin_p_bar = p_bar[:n_bins]
    filled = ~np.isnan(bin_p_bar)
    ax1.bar(
        y_bar[filled],
        height=bin_p_bar[filled],
        width=-widths[filled],
        align="edge",
        color="lightblue",
    )

    # results["bin_freq"] appears to hold values only for the non-NaN bins, so
    # scatter them into a full-length array (empty bins stay at 0).
    anindices = np.where(~np.isnan(p_bar[:-1]))[0]
    bin_freqs = np.zeros(n_bins)
    bin_freqs[anindices] = results["bin_freq"]
    # Same bar geometry as the diagram above, so each frequency lines up with
    # its own bin.
    ax2.bar(y_bar, bin_freqs, width=-widths, align="edge")

    acc_plt = ax2.axvline(x=results["accuracy"], ls="solid", lw=3, c="black", label="Accuracy")
    conf_plt = ax2.axvline(
        x=results["p_bar_cont"], ls="dotted", lw=3, c="#444", label="Avg. confidence"
    )

    # Bin differences
    ax1.set_ylabel("Conditional Expectation")
    ax1.set_ylim([-0.05, 1.05])  # respective to bin range
    ax1.legend(loc="lower right")
    ax1.set_title("Reliability Diagram")

    # Bin frequencies
    ax2.set_xlabel("Confidence")
    ax2.set_ylabel("Count")
    ax2.legend(handles=[acc_plt, conf_plt], loc="upper left")

    fig.tight_layout()
    return fig
def compute_and_plot(data, n_bins, bin_range, scheme, proxy, p):
    """Compute the ECE for the table entered in the UI and plot its reliability diagram.

    Args:
        data: pandas DataFrame with "predictions" and "references" columns.
            Rows with missing values are ignored. A prediction cell may be a
            list or its string form (e.g. "[0.6, 0.2, 0.2]"), which is parsed
            with ``ast.literal_eval`` (literals only, no code execution).
        n_bins: number of bins forwarded to the metric.
        bin_range: not used (its control is hidden in the UI).
        scheme: binning scheme forwarded to the metric ("equal-range" or
            "equal-mass" in the UI).
        proxy: bin proxy forwarded to the metric ("upper-edge" or "center").
        p: value forwarded to the metric's ``p`` (1, 2 or inf in the UI).

    Returns:
        Tuple of ``results["ECE"]`` and the figure built by ``reliability_plot``.

    Raises:
        TypeError: if ``data`` is not a pandas DataFrame.
    """
    # DEV: check on invalid datatypes with better warnings
    if not isinstance(data, pd.DataFrame):
        raise TypeError(
            "Expected a pandas DataFrame with 'predictions' and 'references' "
            f"columns, got {type(data).__name__}"
        )
    # Work on a cleaned copy rather than mutating the caller's frame.
    data = data.dropna()
    predictions = [
        prediction if isinstance(prediction, list) else ast.literal_eval(prediction)
        for prediction in data["predictions"]
    ]
    references = list(data["references"])
    results = metric._compute(
        predictions,
        references,
        n_bins=int(n_bins),  # cast defensively in case the control yields a float
        scheme=scheme,
        proxy=proxy,
        p=p,
        detail=True,
    )
    plot = reliability_plot(results)
    return results["ECE"], plot
# Output components: the ECE value and the reliability-diagram figure returned
# by compute_and_plot. gr.Textbox replaces the deprecated gr.outputs.Textbox.
outputs = [gr.Textbox(label="ECE"), gr.Plot(label="Reliability diagram")]

# Keep a handle on the Interface itself (previously `iface` was bound to the
# return value of .launch()), then start the app.
iface = gr.Interface(
    fn=compute_and_plot,
    inputs=[component] + sliders,
    outputs=outputs,
    description=metric.info.description,
    article=evaluate.utils.parse_readme(local_path / "README.md"),
    title=f"Metric: {metric.name}",
    # examples=sample_data,  # disabled: raised "ValueError: Examples argument must either be a directory or a nested list, where each sublist represents a set of inputs."
)
iface.launch()