File size: 8,760 Bytes
e9f9fd3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
import math
import re
from itertools import permutations

import ipywidgets as widgets
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

from ..tabular import TabularDataBunch
from ..train import ClassificationInterpretation
class ClassConfusion():
    """Plot the most confused datapoints and statistics for the models misses.

    Building an instance runs the whole pipeline: it ranks every item by loss,
    builds one notebook tab per variable (tabular data) or per confused class
    pair (any other data), fills the tabs with plots and displays the widget.

    Parameters:
        interp: interpretation object. This class reads `learn.data`, `data`,
            `losses`, `pred_class`, `ds_type`, `top_losses()` and
            `most_confused()` from it.
        classlist: class names to compare. Unless `is_ordered` is set, every
            ordered pair of two names is generated with `permutations`.
        is_ordered: when True, `classlist` is used as the list of pairs as-is.
        cut_off: a categorical variable with more distinct values than this is
            reported with a message instead of being plotted.
        varlist: tabular variables to create tabs for; defaults to every data
            column of the gathered frames.
        figsize: figure size handed to `plt.subplots`.
    """

    # File names shaped like `<stem>_<digits><rest>`; used for image titles.
    _FNAME_RE = re.compile(r'([^/*]+)_\d+.*$')

    def __init__(self, interp:'ClassificationInterpretation', classlist:list,
                 is_ordered:bool=False, cut_off:int=100, varlist:list=None,
                 figsize:tuple=(8,8)):
        self.interp = interp
        self._is_tab = isinstance(interp.learn.data, TabularDataBunch)
        # Stay None when no "Normalize" proc is found, so that the frame
        # builder can skip de-normalisation instead of raising AttributeError.
        self.means = None
        self.stds = None
        if self._is_tab:
            if interp.learn.data.train_ds.x.cont_names != []:
                for pos, proc in enumerate(interp.learn.data.procs):
                    if "Normalize" in str(proc):
                        # NOTE(review): assumes the fitted procs on the train
                        # processor are in the same order as `data.procs`.
                        fitted = interp.learn.data.train_ds.x.processor[0].procs[pos]
                        self.means = fitted.means
                        self.stds = fitted.stds
        self.is_ordered = is_ordered
        self.cut_off = cut_off
        self.figsize = figsize
        self.varlist = varlist
        self.classl = classlist
        self._show_losses(classlist)

    def _show_losses(self, classl:list, **kwargs):
        """Checks if the model is for Tabular or Images and gathers top losses.

        `classl` and `kwargs` are accepted for compatibility but not read; the
        class list is taken from `self.classl`.
        """
        # Asking for as many top losses as there are losses ranks every item.
        _, self.tl_idx = self.interp.top_losses(len(self.interp.losses))
        if self._is_tab:
            self._tab_losses()
        else:
            self._create_tabs()

    def _create_tabs(self):
        """Creates a tab for each variable (tabular) or confused pair (images).

        For non-tabular data the user is prompted with `input()` for the
        number of images to show per tab.
        """
        self.lis = self.classl if self.is_ordered else list(permutations(self.classl, 2))
        if self._is_tab:
            self._boxes = len(self.df_list)
            self._cols = math.ceil(math.sqrt(self._boxes))
            self._rows = math.ceil(self._boxes/self._cols)
            # The last column of every frame is its label column, not a variable.
            self.tbnames = list(self.df_list[0].columns)[:-1] if self.varlist is None else self.varlist
        else:
            confused = self.interp.most_confused()
            self._ranges = []
            self.tbnames = []
            self._boxes = int(input('Please enter a value for `k`, or the top images you will see: '))
            # Compare as tuples so pairs given as lists also match, and so a
            # repeated pair does not produce a duplicate tab.
            wanted = {tuple(pair) for pair in self.lis}
            for entry in confused:
                if tuple(entry[0:2]) in wanted:
                    self._ranges.append(entry[2])
                    self.tbnames.append(str(entry[0] + ' | ' + entry[1]))
        self.tabs = widgets.Tab()
        self.tabs.children = [widgets.Output() for _ in self.tbnames]
        for pos, name in enumerate(self.tbnames):
            self.tabs.set_title(pos, name)
        self._populate_tabs()

    def _populate_tabs(self):
        "Adds relevant graphs to each tab, then displays the tab widget"
        with tqdm(total=len(self.tbnames)) as pbar:
            for i, tab in enumerate(self.tbnames):
                # Output produced inside the context is captured by that tab.
                with self.tabs.children[i]:
                    if self._is_tab:
                        self._plot_tab(tab)
                    else:
                        self._plot_imgs(tab, i)
                pbar.update(1)
        # NOTE(review): `display` is not imported in the visible part of this
        # file; presumably it comes from the notebook namespace - confirm.
        display(self.tabs)

    def _plot_tab(self, tab:str):
        """Generates the graphs for the tabular variable `tab`.

        Draws one subplot per frame in `self.df_list`: a bar chart of value
        counts for categorical variables, a histogram (plus a KDE when there
        are at least two distinct values) for continuous ones.
        """
        # squeeze=False always yields a 2-D array of axes, so a single frame
        # no longer produces a bare Axes object that cannot be indexed.
        if self._boxes is not None:
            fig, grid = plt.subplots(self._boxes, figsize=self.figsize, squeeze=False)
        else:
            fig, grid = plt.subplots(self._rows, self._cols, figsize=self.figsize, squeeze=False)
        axes = grid.ravel()
        fig.subplots_adjust(hspace=.5)
        for j, frame in enumerate(self.df_list):
            title = f'{frame.columns[-1]} {tab} distribution'
            if frame.empty:
                # Nothing was misclassified this way; leave a titled empty axis.
                axes[j].set_title(title)
                continue
            if tab in self.cat_names:
                counts = frame[tab].value_counts()
                # Number of bars, i.e. distinct category values.
                n_levels = len(counts)
                if n_levels < 10:
                    counts.plot(kind='bar', title=title, ax=axes[j], rot=0, width=.75)
                elif n_levels > self.cut_off:
                    print(f'Number of values is above {self.cut_off}')
                else:
                    counts.plot(kind='barh', title=title, ax=axes[j], width=.75)
            else:
                vals = frame[tab]
                axs = vals.plot(kind='hist', ax=axes[j], title=title, y='Frequency')
                axs.set_ylabel('Frequency')
                if vals.nunique() > 1:
                    vals.plot(kind='kde', ax=axs, title=title, secondary_y=True)
                else:
                    print('Less than two unique values, cannot graph the KDE')
        # Layout is fixed by `subplots_adjust` above; it has to be set before
        # the figure is shown.
        plt.show()

    def _plot_imgs(self, tab:str, i:int ,**kwargs):
        """Plots the most confused images for the class pair named by `tab`.

        `tab` is '<actual> | <predicted>' and `i` is its position in
        `self.tbnames`/`self._ranges`. Items are visited in top-loss order and
        at most `min(self._ranges[i], self._boxes)` of them are drawn.
        """
        classes_gnd = self.interp.data.classes
        actual, predicted = self._split_tab_name(tab)
        n_show = min(self._ranges[i], self._boxes)
        rows, cols = self._grid_shape(n_show)
        fig, ax = plt.subplots(rows, cols, figsize=self.figsize)
        for axi in ax.ravel():
            axi.set_axis_off()
        dataset = self.interp.data.dl(self.interp.ds_type).dataset
        shown = 0
        for idx in self.tl_idx:
            if shown >= n_show:
                break
            _, cl = dataset[idx]
            if str(cl) == actual and str(classes_gnd[self.interp.pred_class[idx]]) == predicted:
                # NOTE(review): the image is read from `valid_ds` while the
                # label above comes from `dl(ds_type)`; these agree only when
                # `ds_type` is the validation set - confirm.
                img, lbl = self.interp.data.valid_ds[idx]
                fn = self.interp.data.valid_ds.x.items[idx]
                row, col = divmod(shown, cols)
                img.show(ax=ax[row, col])
                ax[row, col].set_title(self._file_title(fn))
                shown += 1
        plt.tight_layout()
        plt.show()

    def _tab_losses(self, **kwargs):
        """Gathers dataframes of the combinations data, then builds the tabs.

        `self.df_list` receives one frame holding every item (label column
        'Original') followed by one frame per class pair, holding the items
        whose predicted class is the first element of the pair and whose
        actual class is the second.
        """
        classes = self.interp.data.classes
        cat_names = self.interp.data.x.cat_names
        cont_names = self.interp.data.x.cont_names
        comb = self.classl if self.is_ordered else list(permutations(self.classl,2))
        dataset = self.interp.data.dl(self.interp.ds_type).dataset
        # Read every item once and remember its (predicted, actual) classes, so
        # the per-pair frames below are plain filters over `records`.
        records = []
        names = []
        for idx in self.tl_idx:
            da, cl = dataset[idx]
            names = da.names
            records.append((classes[self.interp.pred_class[idx]],
                            classes[int(cl)],
                            self._row_values(da)))
        all_rows = [values for _, _, values in records]
        self.df_list = [self._build_frame(all_rows, names, cont_names, 'Original')]
        for pair in comb:
            rows = [values for pred, actual, values in records
                    if pred == pair[0] and actual == pair[1]]
            self.df_list.append(self._build_frame(rows, names, cont_names, str(pair)))
        self.cat_names = cat_names
        self._create_tabs()

    def _build_frame(self, rows:list, names:list, cont_names:list, label:str):
        "Builds a frame from `rows`, de-normalises continuous columns, appends a `label` column"
        frame = pd.DataFrame(rows, columns=names)
        for var in cont_names:
            if var not in frame.columns:
                continue
            column = frame[var].astype(float)
            if self.means is not None and self.stds is not None:
                # Undo the normalisation: value * std + mean.
                column = column * self.stds[var] + self.means[var]
            frame[var] = column
        frame[label] = label
        return frame

    @staticmethod
    def _row_values(item) -> list:
        "Returns the category values (leading whitespace removed) and 4-decimal continuous values of `item` as strings"
        n_cats = len(item.cats)
        values = [f'{item.classes[name][code]}'.lstrip()
                  for code, name in zip(item.cats, item.names[:n_cats])]
        values += [f'{cont:.4f}' for cont, _ in zip(item.conts, item.names[n_cats:])]
        return values

    @staticmethod
    def _grid_shape(n_images:int) -> tuple:
        "Returns `(rows, cols)` for a near-square grid holding `n_images`, never smaller than 2x2"
        # A 2x2 minimum keeps `plt.subplots` returning a 2-D array of axes.
        if n_images < 4:
            return 2, 2
        cols = math.ceil(math.sqrt(n_images))
        rows = math.ceil(n_images/cols)
        return rows, cols

    @staticmethod
    def _split_tab_name(tab:str) -> tuple:
        "Splits an '<actual> | <predicted>' tab name on its first ' | ' separator"
        actual, _, predicted = tab.partition(' | ')
        return actual, predicted

    @staticmethod
    def _file_title(fn) -> str:
        "Returns the `<stem>_<digits>...` part of `fn`, or its last path component when there is none"
        text = str(fn)
        found = ClassConfusion._FNAME_RE.search(text)
        if found:
            return found.group(0)
        return re.split(r'[\\/]', text)[-1]
|