import math
import re
from itertools import permutations
from pathlib import Path

import ipywidgets as widgets
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

from ..tabular import TabularDataBunch
from ..train import ClassificationInterpretation

# Pulls a "<name>_<digits>..." tail out of an item path to use as an image title.
_FILENAME_PATTERN = re.compile(r'([^/*]+)_\d+.*$')


class ClassConfusion():
    """Plot the most confused datapoints and statistics for the models misses.

    For a tabular learner (``interp.learn.data`` is a `TabularDataBunch`) one
    tab is built per variable, showing that variable's distribution for the
    whole dataset and for every requested class pair. For any other learner
    one tab is built per confused class pair, showing the top-loss images.

    Parameters
    ----------
    interp : the interpretation object whose losses/predictions are inspected.
    classlist : class names to compare. With ``is_ordered=False`` every ordered
        pair (``itertools.permutations(classlist, 2)``) is used; with
        ``is_ordered=True`` `classlist` must already be a list of pairs.
    is_ordered : see `classlist`.
    cut_off : categorical variables with more distinct values than this are
        reported instead of plotted.
    varlist : tabular variables to show; defaults to every column.
    figsize : size handed to ``plt.subplots``.

    Note that constructing the object immediately gathers data and renders the
    widget; in image mode it also prompts on stdin for the number of images.
    """

    def __init__(self, interp:ClassificationInterpretation, classlist:list,
                 is_ordered:bool=False, cut_off:int=100, varlist:list=None,
                 figsize:tuple=(8,8)):
        self.interp = interp
        self._is_tab = isinstance(interp.learn.data, TabularDataBunch)
        # Always define the statistics so later code can test for their
        # presence instead of failing with AttributeError when the data has
        # continuous variables but no Normalize processor.
        self.means = None
        self.stds = None
        if self._is_tab:
            data = interp.learn.data
            if data.train_ds.x.cont_names:
                for pos, proc in enumerate(data.procs):
                    if "Normalize" in str(proc):
                        fitted = data.train_ds.x.processor[0].procs[pos]
                        self.means = fitted.means
                        self.stds = fitted.stds
        self.is_ordered = is_ordered
        self.cut_off = cut_off
        self.figsize = figsize
        self.varlist = varlist
        self.classl = classlist
        self._show_losses(classlist)

    def _show_losses(self, classl:list, **kwargs):
        "Checks if the model is for Tabular or Images and gathers top losses"
        # `classl` and `kwargs` are accepted for interface compatibility only;
        # the class list is read from `self.classl`.
        _, self.tl_idx = self.interp.top_losses(len(self.interp.losses))
        if self._is_tab:
            self._tab_losses()
        else:
            self._create_tabs()

    def _create_tabs(self):
        "Creates a tab for each variable (tabular) or confused class pair (images)"
        self.lis = self.classl if self.is_ordered else list(permutations(self.classl, 2))
        if self._is_tab:
            self._boxes = len(self.df_list)
            self._cols = math.ceil(math.sqrt(self._boxes))
            self._rows = math.ceil(self._boxes/self._cols)
            # The last column of each frame is the constant label column.
            self.tbnames = list(self.df_list[0].columns)[:-1] if self.varlist is None else self.varlist
        else:
            self._ranges = []
            self.tbnames = []
            self._boxes = int(input('Please enter a value for `k`, or the top images you will see: '))
            # Normalise to tuples so user-supplied pairs given as lists match too.
            wanted = {tuple(pair) for pair in self.lis}
            # Each entry is indexed as (first class, second class, count).
            for entry in self.interp.most_confused():
                if tuple(entry[0:2]) in wanted:
                    self._ranges.append(entry[2])
                    self.tbnames.append(f'{entry[0]} | {entry[1]}')
        self.tabs = widgets.Tab()
        self.tabs.children = [widgets.Output() for _ in self.tbnames]
        for i, name in enumerate(self.tbnames):
            self.tabs.set_title(i, name)
        self._populate_tabs()

    def _populate_tabs(self):
        "Adds relevant graphs to each tab"
        with tqdm(total=len(self.tbnames)) as pbar:
            for i, tab in enumerate(self.tbnames):
                # Anything drawn inside the context is captured by that tab.
                with self.tabs.children[i]:
                    if self._is_tab:
                        self._plot_tab(tab)
                    else:
                        self._plot_imgs(tab, i)
                pbar.update(1)
        # NOTE(review): `display` is not imported in this file; presumably it
        # is the builtin injected by IPython/Jupyter — confirm.
        display(self.tabs)

    def _plot_tab(self, tab:str):
        """Plot the distribution of variable `tab` for every gathered data frame.

        Categorical variables get a bar chart of value counts, continuous ones
        a histogram plus (when at least two distinct values exist) a KDE.
        """
        # `_create_tabs` always sets `_boxes`, giving one stacked subplot per
        # frame; the `_rows` x `_cols` grid is kept as a fallback.
        stacked = self._boxes is not None
        if stacked:
            fig, axes = plt.subplots(self._boxes, figsize=self.figsize, squeeze=False)
        else:
            fig, axes = plt.subplots(self._rows, self._cols, figsize=self.figsize, squeeze=False)
        # Flatten so frame j always maps to axes[j], whatever the grid shape
        # (and even when only a single subplot was created).
        axes = axes.ravel()
        for j, frame in enumerate(self.df_list):
            axis = axes[j]
            title = f'{frame.columns[-1]} {tab} distribution'
            if frame.empty:
                # pandas cannot plot an empty selection; label the slot instead.
                axis.set_title(title)
                print(f'No rows for {frame.columns[-1]}, nothing to plot')
                continue
            if tab in self.cat_names:
                counts = frame[tab].value_counts()
                # Number of distinct categories (not of distinct frequencies).
                n_levels = len(counts)
                if stacked and n_levels < 10:
                    counts.plot(kind='bar', title=title, ax=axis, rot=0, width=.75)
                elif stacked and n_levels > self.cut_off:
                    print(f'Number of values is above {self.cut_off}')
                else:
                    counts.plot(kind='barh', title=title, ax=axis, width=.75)
            else:
                values = frame[tab]
                hist_ax = values.plot(kind='hist', ax=axis, title=title, y='Frequency')
                hist_ax.set_ylabel('Frequency')
                if values.nunique() > 1:
                    values.plot(kind='kde', ax=hist_ax, title=title, secondary_y=True)
                else:
                    print('Less than two unique values, cannot graph the KDE')
        # Lay out before showing; re-apply the vertical spacing afterwards
        # because tight_layout recomputes it.
        fig.tight_layout()
        fig.subplots_adjust(hspace=.5)
        plt.show()

    def _plot_imgs(self, tab:str, i:int ,**kwargs):
        """Plot the top-loss images for the class pair named by `tab`.

        `tab` has the form ``'<first> | <second>'``; an item is shown when its
        label equals the first name and its predicted class the second. At
        most ``min(self._ranges[i], self._boxes)`` images are drawn.
        """
        classes_gnd = self.interp.data.classes
        actual, _, predicted = tab.partition(' | ')
        # Never reserve more slots than there are confused items or than the
        # user asked for; tiny selections still get a 2x2 grid.
        n_show = min(self._ranges[i], self._boxes)
        if n_show < 4:
            rows = cols = 2
        else:
            cols = math.ceil(math.sqrt(n_show))
            rows = math.ceil(n_show/cols)
        fig, axes = plt.subplots(rows, cols, figsize=self.figsize, squeeze=False)
        axes = axes.ravel()
        for axis in axes:
            axis.set_axis_off()
        dataset = self.interp.data.dl(self.interp.ds_type).dataset
        shown = 0
        for idx in self.tl_idx:
            if shown >= n_show:
                break
            _, cl = dataset[idx]
            if str(cl) == actual and str(classes_gnd[self.interp.pred_class[idx]]) == predicted:
                # NOTE(review): the image is read from `valid_ds` while the
                # label comes from `dl(ds_type)`; these agree only when
                # ds_type is the validation set — confirm.
                img, _ = self.interp.data.valid_ds[idx]
                fn = str(self.interp.data.valid_ds.x.items[idx])
                match = _FILENAME_PATTERN.search(fn)
                # Fall back to the bare file name when the pattern is absent.
                caption = match.group(0) if match else Path(fn).name
                img.show(ax=axes[shown])
                axes[shown].set_title(caption)
                shown += 1
        fig.tight_layout()
        plt.show()

    @staticmethod
    def _row_values(da):
        "Return one dataset item as a list: categorical values as text, then continuous values as 4-decimal text"
        n_cats = len(da.cats)
        values = []
        for c, n in zip(da.cats, da.names[:n_cats]):
            text = f'{da.classes[n][c]}'
            # Drop a single leading space if present; values without one
            # (e.g. 'True'/'False') are kept whole.
            values.append(text[1:] if text.startswith(' ') else text)
        for c, _ in zip(da.conts, da.names[n_cats:]):
            values.append(f'{c:.4f}')
        return values

    def _build_frame(self, rows:list, names:list, label:str):
        "Build a DataFrame from `rows`, undo normalisation of continuous columns and append a constant `label` column"
        frame = pd.DataFrame(rows, columns=names)
        for var in self.interp.data.cont_names:
            column = frame[var].astype(float)
            if self.stds is not None and self.means is not None:
                column = column * self.stds[var] + self.means[var]
            frame[var] = column
        frame[label] = label
        return frame

    def _tab_losses(self, **kwargs):
        """Gathers dataframes of the combinations data.

        Fills `self.df_list` with one frame for the whole dataset (labelled
        ``'Original'``) followed by one frame per class pair, where a pair
        ``(a, b)`` selects items predicted as ``a`` whose label is ``b``.
        Rows follow the order of `self.tl_idx`.

        Raises
        ------
        ValueError : when there are no samples to gather.
        """
        classes = self.interp.data.classes
        combos = self.classl if self.is_ordered else list(permutations(self.classl, 2))
        dataset = self.interp.data.dl(self.interp.ds_type).dataset
        # Read every item exactly once and remember its predicted/actual class,
        # instead of re-fetching the whole dataset for each class pair.
        names = None
        records = []
        for idx in self.tl_idx:
            da, cl = dataset[idx]
            if names is None:
                names = da.names
            predicted = classes[self.interp.pred_class[idx]]
            actual = classes[int(cl)]
            records.append((predicted, actual, self._row_values(da)))
        if names is None:
            raise ValueError('No samples available to build the confusion data frames')
        self.df_list = [self._build_frame([row for _, _, row in records], names, 'Original')]
        for combo in combos:
            rows = [row for pred, act, row in records
                    if pred == combo[0] and act == combo[1]]
            self.df_list.append(self._build_frame(rows, names, str(combo)))
        self.cat_names = self.interp.data.x.cat_names
        self._create_tabs()