Spaces:
Configuration error
Configuration error
app and ece done
Browse files
app.py
CHANGED
|
@@ -7,6 +7,13 @@ import gradio as gr
|
|
| 7 |
from evaluate.utils import launch_gradio_widget
|
| 8 |
from ece import ECE
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
sliders = [
|
| 12 |
gr.Slider(0, 100, value=10, label="n_bins"),
|
|
@@ -44,16 +51,6 @@ Switch inputs and compute_fn
|
|
| 44 |
"""
|
| 45 |
|
| 46 |
def reliability_plot(results):
|
| 47 |
-
#CE, calibrated_acc, empirical_acc, weights_ece
|
| 48 |
-
#{"ECE": ECE[0], "y_bar": ECE[1], "p_bar": ECE[2], "bin_freq": ECE[3]}
|
| 49 |
-
import matplotlib.pyplot as plt
|
| 50 |
-
import seaborn as sns
|
| 51 |
-
sns.set_style('white')
|
| 52 |
-
sns.set_context("paper", font_scale=1) # 2
|
| 53 |
-
# plt.rcParams['figure.figsize'] = [10, 7]
|
| 54 |
-
plt.rcParams['figure.dpi'] = 300
|
| 55 |
-
|
| 56 |
-
|
| 57 |
fig = plt.figure()
|
| 58 |
ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2)
|
| 59 |
ax2 = plt.subplot2grid((3, 1), (2, 0))
|
|
@@ -65,9 +62,10 @@ def reliability_plot(results):
|
|
| 65 |
] # np.linspace(0, 1, n_bins)
|
| 66 |
# if upper edge then minus binsize; same for center [but half]
|
| 67 |
|
|
|
|
| 68 |
ax1.plot(
|
| 69 |
-
|
| 70 |
-
|
| 71 |
color="darkgreen",
|
| 72 |
ls="dotted",
|
| 73 |
label="Perfect",
|
|
@@ -79,7 +77,7 @@ def reliability_plot(results):
|
|
| 79 |
bin_freqs[anindices] = results["bin_freq"]
|
| 80 |
ax2.hist(results["y_bar"], results["y_bar"], weights=bin_freqs)
|
| 81 |
|
| 82 |
-
widths = np.diff(results["y_bar"])
|
| 83 |
for j, bin in enumerate(results["y_bar"]):
|
| 84 |
perfect = results["y_bar"][j]
|
| 85 |
empirical = results["p_bar"][j]
|
|
@@ -87,7 +85,7 @@ def reliability_plot(results):
|
|
| 87 |
if np.isnan(empirical):
|
| 88 |
continue
|
| 89 |
|
| 90 |
-
ax1.bar([perfect], height=[empirical], width=-
|
| 91 |
|
| 92 |
if perfect == empirical:
|
| 93 |
continue
|
|
@@ -137,10 +135,10 @@ def compute_and_plot(data, n_bins, bin_range, scheme, proxy, p):
|
|
| 137 |
)
|
| 138 |
|
| 139 |
plot = reliability_plot(results)
|
| 140 |
-
return results["ECE"], plt.gcf()
|
| 141 |
|
| 142 |
|
| 143 |
-
outputs = [gr.outputs.Textbox(label="ECE"), gr.
|
| 144 |
|
| 145 |
iface = gr.Interface(
|
| 146 |
fn=compute_and_plot,
|
|
@@ -148,26 +146,5 @@ iface = gr.Interface(
|
|
| 148 |
outputs=outputs,
|
| 149 |
description=metric.info.description,
|
| 150 |
article=metric.info.citation,
|
| 151 |
-
# examples=sample_data
|
| 152 |
-
)
|
| 153 |
-
|
| 154 |
-
# ValueError: Examples argument must either be a directory or a nested list, where each sublist represents a set of inputs.
|
| 155 |
-
|
| 156 |
-
iface.launch()
|
| 157 |
-
|
| 158 |
-
# dict = {"ECE": ECE[0], "y_bar": ECE[1], "p_bar": ECE[2], "bin_freq": ECE[3]}
|
| 159 |
-
|
| 160 |
-
# references=[0, 1, 2], predictions=)
|
| 161 |
-
# https://gradio.app/getting_started/#multiple-inputs-and-outputs
|
| 162 |
-
## fix with sliders for all kwargs
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
"""
|
| 166 |
-
DEV: #might be nice to also plot reliability diagram
|
| 167 |
-
have sliders for kwargs :)
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
metric = ECE()
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
"""
|
|
|
|
| 7 |
from evaluate.utils import launch_gradio_widget
|
| 8 |
from ece import ECE
|
| 9 |
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import seaborn as sns
|
| 12 |
+
sns.set_style('white')
|
| 13 |
+
sns.set_context("paper", font_scale=1) # 2
|
| 14 |
+
# plt.rcParams['figure.figsize'] = [10, 7]
|
| 15 |
+
plt.rcParams['figure.dpi'] = 300
|
| 16 |
+
plt.switch_backend('agg') #; https://stackoverflow.com/questions/14694408/runtimeerror-main-thread-is-not-in-main-loop
|
| 17 |
|
| 18 |
sliders = [
|
| 19 |
gr.Slider(0, 100, value=10, label="n_bins"),
|
|
|
|
| 51 |
"""
|
| 52 |
|
| 53 |
def reliability_plot(results):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
fig = plt.figure()
|
| 55 |
ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2)
|
| 56 |
ax2 = plt.subplot2grid((3, 1), (2, 0))
|
|
|
|
| 62 |
] # np.linspace(0, 1, n_bins)
|
| 63 |
# if upper edge then minus binsize; same for center [but half]
|
| 64 |
|
| 65 |
+
ranged = np.linspace(bin_range[0], bin_range[1], n_bins)
|
| 66 |
ax1.plot(
|
| 67 |
+
ranged,
|
| 68 |
+
ranged,
|
| 69 |
color="darkgreen",
|
| 70 |
ls="dotted",
|
| 71 |
label="Perfect",
|
|
|
|
| 77 |
bin_freqs[anindices] = results["bin_freq"]
|
| 78 |
ax2.hist(results["y_bar"], results["y_bar"], weights=bin_freqs)
|
| 79 |
|
| 80 |
+
#widths = np.diff(results["y_bar"])
|
| 81 |
for j, bin in enumerate(results["y_bar"]):
|
| 82 |
perfect = results["y_bar"][j]
|
| 83 |
empirical = results["p_bar"][j]
|
|
|
|
| 85 |
if np.isnan(empirical):
|
| 86 |
continue
|
| 87 |
|
| 88 |
+
ax1.bar([perfect], height=[empirical], width=-ranged[j], align="edge", color="lightblue")
|
| 89 |
|
| 90 |
if perfect == empirical:
|
| 91 |
continue
|
|
|
|
| 135 |
)
|
| 136 |
|
| 137 |
plot = reliability_plot(results)
|
| 138 |
+
return results["ECE"], plot #plt.gcf()
|
| 139 |
|
| 140 |
|
| 141 |
+
outputs = [gr.outputs.Textbox(label="ECE"), gr.Plot(label="Reliability diagram")]
|
| 142 |
|
| 143 |
iface = gr.Interface(
|
| 144 |
fn=compute_and_plot,
|
|
|
|
| 146 |
outputs=outputs,
|
| 147 |
description=metric.info.description,
|
| 148 |
article=metric.info.citation,
|
| 149 |
+
# examples=sample_data  # disabled: it raised "ValueError: Examples argument must either be a directory or a nested list, where each sublist represents a set of inputs."
|
| 150 |
+
).launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ece.py
CHANGED
|
@@ -80,7 +80,7 @@ BAD_WORDS_URL = ""
|
|
| 80 |
def create_bins(n_bins=10, scheme="equal-range", bin_range=None, P=None):
|
| 81 |
assert scheme in [
|
| 82 |
"equal-range",
|
| 83 |
-
"equal-
|
| 84 |
], f"This binning scheme {scheme} is not implemented yet"
|
| 85 |
|
| 86 |
if bin_range is None:
|
|
@@ -106,8 +106,9 @@ def create_bins(n_bins=10, scheme="equal-range", bin_range=None, P=None):
|
|
| 106 |
# rightmost entry per equal size group
|
| 107 |
for cur_group in range(n_bins - 1):
|
| 108 |
bin_upper_edges += [max(groups[cur_group])]
|
| 109 |
-
bin_upper_edges += [np.inf] # always +1 for right edges
|
| 110 |
bins = np.array(bin_upper_edges)
|
|
|
|
| 111 |
|
| 112 |
return bins
|
| 113 |
|
|
@@ -201,7 +202,7 @@ def top_1_CE(Y, P, **kwargs):
|
|
| 201 |
n_bins=kwargs["n_bins"], bin_range=kwargs["bin_range"], scheme=kwargs["scheme"], P=p_max
|
| 202 |
)
|
| 203 |
CE = CE_estimate(y_correct, p_max, bins=bins, proxy=kwargs["proxy"], detail=kwargs["detail"])
|
| 204 |
-
if
|
| 205 |
return {"ECE": CE[0], "y_bar": CE[1], "p_bar": CE[2], "bin_freq": CE[3], "p_bar_cont": np.mean(p_max,-1), "accuracy": np.mean(y_correct)}
|
| 206 |
return CE
|
| 207 |
|
|
|
|
| 80 |
def create_bins(n_bins=10, scheme="equal-range", bin_range=None, P=None):
|
| 81 |
assert scheme in [
|
| 82 |
"equal-range",
|
| 83 |
+
"equal-mass",
|
| 84 |
], f"This binning scheme {scheme} is not implemented yet"
|
| 85 |
|
| 86 |
if bin_range is None:
|
|
|
|
| 106 |
# rightmost entry per equal size group
|
| 107 |
for cur_group in range(n_bins - 1):
|
| 108 |
bin_upper_edges += [max(groups[cur_group])]
|
| 109 |
+
bin_upper_edges += [1.01] #[np.inf] # always +1 for right edges
|
| 110 |
bins = np.array(bin_upper_edges)
|
| 111 |
+
# NOTE: a finite 1.01 is used as the last edge instead of np.inf, which raised "OverflowError: cannot convert float infinity to integer"
|
| 112 |
|
| 113 |
return bins
|
| 114 |
|
|
|
|
| 202 |
n_bins=kwargs["n_bins"], bin_range=kwargs["bin_range"], scheme=kwargs["scheme"], P=p_max
|
| 203 |
)
|
| 204 |
CE = CE_estimate(y_correct, p_max, bins=bins, proxy=kwargs["proxy"], detail=kwargs["detail"])
|
| 205 |
+
if kwargs["detail"]:
|
| 206 |
return {"ECE": CE[0], "y_bar": CE[1], "p_bar": CE[2], "bin_freq": CE[3], "p_bar_cont": np.mean(p_max,-1), "accuracy": np.mean(y_correct)}
|
| 207 |
return CE
|
| 208 |
|