Spaces:
Configuration error
Configuration error
Fix to reliability diagram - correct with test
Browse files
app.py
CHANGED
|
@@ -48,34 +48,31 @@ def reliability_plot(results):
|
|
| 48 |
# DEV: nicer would be to plot like a polygon
|
| 49 |
# see: https://github.com/markus93/fit-on-the-test/blob/main/Experiments_Synthetic/binnings.py
|
| 50 |
|
| 51 |
-
def over_under_confidence(results):
|
| 52 |
-
colors = []
|
| 53 |
-
for j, bin in enumerate(results["y_bar"]):
|
| 54 |
-
perfect = results["y_bar"][j]
|
| 55 |
-
empirical = results["p_bar"][j]
|
| 56 |
-
|
| 57 |
-
bin_color = (
|
| 58 |
-
"limegreen"
|
| 59 |
-
if np.allclose(perfect, empirical)
|
| 60 |
-
else "dodgerblue"
|
| 61 |
-
if empirical < perfect
|
| 62 |
-
else "orangered"
|
| 63 |
-
)
|
| 64 |
-
colors.append(bin_color)
|
| 65 |
-
return colors
|
| 66 |
-
|
| 67 |
fig, ax1, ax2 = default_plot()
|
| 68 |
|
| 69 |
# Bin differences
|
| 70 |
bins_with_left_edge = np.insert(results["y_bar"], 0, 0, axis=0)
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
)
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
ax1handles = [
|
| 81 |
mpatches.Patch(color="orangered", label="Overconfident"),
|
|
@@ -84,12 +81,11 @@ def reliability_plot(results):
|
|
| 84 |
]
|
| 85 |
|
| 86 |
# Bin frequencies
|
| 87 |
-
anindices = np.where(~np.isnan(results["p_bar"]
|
| 88 |
-
|
| 89 |
-
bin_freqs = np.zeros(n_bins)
|
| 90 |
bin_freqs[anindices] = results["bin_freq"]
|
| 91 |
-
|
| 92 |
-
|
| 93 |
)
|
| 94 |
|
| 95 |
acc_plt = ax2.axvline(x=results["accuracy"], ls="solid", lw=3, c="black", label="Accuracy")
|
|
@@ -148,8 +144,8 @@ component = gr.inputs.Dataframe(
|
|
| 148 |
)
|
| 149 |
|
| 150 |
component.value = [
|
| 151 |
-
[[0.
|
| 152 |
-
[[0.
|
| 153 |
[[0, 0.95, 0.05], 1],
|
| 154 |
]
|
| 155 |
sample_data = [[component] + slider_defaults]
|
|
|
|
| 48 |
# DEV: nicer would be to plot like a polygon
|
| 49 |
# see: https://github.com/markus93/fit-on-the-test/blob/main/Experiments_Synthetic/binnings.py
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
fig, ax1, ax2 = default_plot()
|
| 52 |
|
| 53 |
# Bin differences
|
| 54 |
bins_with_left_edge = np.insert(results["y_bar"], 0, 0, axis=0)
|
| 55 |
+
bins_with_right_edge = np.insert(results["y_bar"], -1, 1.0, axis=0)
|
| 56 |
+
bins_with_leftright_edge = np.insert(bins_with_left_edge, -1, 1.0, axis=0)
|
| 57 |
+
weights = np.nan_to_num(results["p_bar"], copy=True, nan=0)
|
| 58 |
+
|
| 59 |
+
# NOTE: the histogram API is strange
|
| 60 |
+
_, _, patches = ax1.hist(
|
| 61 |
+
bins_with_left_edge,
|
| 62 |
+
weights=weights,
|
| 63 |
+
bins=bins_with_leftright_edge,
|
| 64 |
)
|
| 65 |
+
for b in range(len(patches)):
|
| 66 |
+
perfect = bins_with_right_edge[b] # if b != n_bins else
|
| 67 |
+
empirical = weights[b] # patches[b]._height
|
| 68 |
+
bin_color = (
|
| 69 |
+
"limegreen"
|
| 70 |
+
if perfect == empirical
|
| 71 |
+
else "dodgerblue"
|
| 72 |
+
if empirical < perfect
|
| 73 |
+
else "orangered"
|
| 74 |
+
)
|
| 75 |
+
patches[b].set_facecolor(bin_color) # color based on over/underconfidence
|
| 76 |
|
| 77 |
ax1handles = [
|
| 78 |
mpatches.Patch(color="orangered", label="Overconfident"),
|
|
|
|
| 81 |
]
|
| 82 |
|
| 83 |
# Bin frequencies
|
| 84 |
+
anindices = np.where(~np.isnan(results["p_bar"]))[0]
|
| 85 |
+
bin_freqs = np.zeros(len(results["p_bar"]))
|
|
|
|
| 86 |
bin_freqs[anindices] = results["bin_freq"]
|
| 87 |
+
ax2.hist(
|
| 88 |
+
bins_with_left_edge, weights=bin_freqs, color="midnightblue", bins=bins_with_leftright_edge
|
| 89 |
)
|
| 90 |
|
| 91 |
acc_plt = ax2.axvline(x=results["accuracy"], ls="solid", lw=3, c="black", label="Accuracy")
|
|
|
|
| 144 |
)
|
| 145 |
|
| 146 |
component.value = [
|
| 147 |
+
[[0.6, 0.2, 0.2], 0],
|
| 148 |
+
[[0.7, 0.1, 0.2], 2],
|
| 149 |
[[0, 0.95, 0.05], 1],
|
| 150 |
]
|
| 151 |
sample_data = [[component] + slider_defaults]
|
ece.py
CHANGED
|
@@ -21,7 +21,6 @@ import numpy as np
|
|
| 21 |
from typing import Dict, Optional
|
| 22 |
|
| 23 |
|
| 24 |
-
|
| 25 |
# TODO: Add BibTeX citation
|
| 26 |
_CITATION = """\
|
| 27 |
@InProceedings{huggingface:module,
|
|
@@ -103,9 +102,9 @@ def create_bins(n_bins=10, scheme="equal-range", bin_range=None, P=None):
|
|
| 103 |
# rightmost entry per equal size group
|
| 104 |
for cur_group in range(n_bins - 1):
|
| 105 |
bin_upper_edges += [max(groups[cur_group])]
|
| 106 |
-
bin_upper_edges += [1.01]
|
| 107 |
bins = np.array(bin_upper_edges)
|
| 108 |
-
#OverflowError: cannot convert float infinity to integer
|
| 109 |
|
| 110 |
return bins
|
| 111 |
|
|
@@ -200,7 +199,14 @@ def top_1_CE(Y, P, **kwargs):
|
|
| 200 |
)
|
| 201 |
CE = CE_estimate(y_correct, p_max, bins=bins, proxy=kwargs["proxy"], detail=kwargs["detail"])
|
| 202 |
if kwargs["detail"]:
|
| 203 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
return CE
|
| 205 |
|
| 206 |
|
|
@@ -306,9 +312,18 @@ def test_ECE():
|
|
| 306 |
print(f"ECE: {res['ECE']}")
|
| 307 |
|
| 308 |
res = ECE()._compute(predictions, references, detail=True)
|
| 309 |
-
import pdb; pdb.set_trace() # breakpoint 25274412 //
|
| 310 |
-
|
| 311 |
print(f"ECE: {res['ECE']}")
|
| 312 |
|
| 313 |
-
|
| 314 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
from typing import Dict, Optional
|
| 22 |
|
| 23 |
|
|
|
|
| 24 |
# TODO: Add BibTeX citation
|
| 25 |
_CITATION = """\
|
| 26 |
@InProceedings{huggingface:module,
|
|
|
|
| 102 |
# rightmost entry per equal size group
|
| 103 |
for cur_group in range(n_bins - 1):
|
| 104 |
bin_upper_edges += [max(groups[cur_group])]
|
| 105 |
+
bin_upper_edges += [1.01] # [np.inf] # always +1 for right edges
|
| 106 |
bins = np.array(bin_upper_edges)
|
| 107 |
+
# OverflowError: cannot convert float infinity to integer
|
| 108 |
|
| 109 |
return bins
|
| 110 |
|
|
|
|
| 199 |
)
|
| 200 |
CE = CE_estimate(y_correct, p_max, bins=bins, proxy=kwargs["proxy"], detail=kwargs["detail"])
|
| 201 |
if kwargs["detail"]:
|
| 202 |
+
return {
|
| 203 |
+
"ECE": CE[0],
|
| 204 |
+
"y_bar": CE[1],
|
| 205 |
+
"p_bar": CE[2],
|
| 206 |
+
"bin_freq": CE[3],
|
| 207 |
+
"p_bar_cont": np.mean(p_max, -1),
|
| 208 |
+
"accuracy": np.mean(y_correct),
|
| 209 |
+
}
|
| 210 |
return CE
|
| 211 |
|
| 212 |
|
|
|
|
| 312 |
print(f"ECE: {res['ECE']}")
|
| 313 |
|
| 314 |
res = ECE()._compute(predictions, references, detail=True)
|
|
|
|
|
|
|
| 315 |
print(f"ECE: {res['ECE']}")
|
| 316 |
|
| 317 |
+
|
| 318 |
+
def test_deterministic():
|
| 319 |
+
res = ECE()._compute(
|
| 320 |
+
references=[0, 1, 2],
|
| 321 |
+
predictions=[[0.63, 0.2, 0.2], [0, 0.95, 0.05], [0.72, 0.1, 0.2]],
|
| 322 |
+
detail=True,
|
| 323 |
+
)
|
| 324 |
+
print(f"ECE: {res['ECE']}\n {res}")
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
if __name__ == "__main__":
|
| 328 |
+
test_deterministic()
|
| 329 |
+
test_ECE()
|