| import math | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| from tqdm import tqdm | |
| from itertools import permutations | |
| from ..tabular import TabularDataBunch | |
| from ..train import ClassificationInterpretation | |
| import ipywidgets as widgets | |
| class ClassConfusion(): | |
| "Plot the most confused datapoints and statistics for the models misses." | |
| def __init__(self, interp:ClassificationInterpretation, classlist:list, | |
| is_ordered:bool=False, cut_off:int=100, varlist:list=None, | |
| figsize:tuple=(8,8)): | |
| self.interp = interp | |
| self._is_tab = isinstance(interp.learn.data, TabularDataBunch) | |
| if self._is_tab: | |
| if interp.learn.data.train_ds.x.cont_names != []: | |
| for x in range(len(interp.learn.data.procs)): | |
| if "Normalize" in str(interp.learn.data.procs[x]): | |
| self.means = interp.learn.data.train_ds.x.processor[0].procs[x].means | |
| self.stds = interp.learn.data.train_ds.x.processor[0].procs[x].stds | |
| self.is_ordered = is_ordered | |
| self.cut_off = cut_off | |
| self.figsize = figsize | |
| self.varlist = varlist | |
| self.classl = classlist | |
| self._show_losses(classlist) | |
| def _show_losses(self, classl:list, **kwargs): | |
| "Checks if the model is for Tabular or Images and gathers top losses" | |
| _, self.tl_idx = self.interp.top_losses(len(self.interp.losses)) | |
| self._tab_losses() if self._is_tab else self._create_tabs() | |
| def _create_tabs(self): | |
| "Creates a tab for each variable" | |
| self.lis = self.classl if self.is_ordered else list(permutations(self.classl, 2)) | |
| if self._is_tab: | |
| self._boxes = len(self.df_list) | |
| self._cols = math.ceil(math.sqrt(self._boxes)) | |
| self._rows = math.ceil(self._boxes/self._cols) | |
| self.tbnames = list(self.df_list[0].columns)[:-1] if self.varlist is None else self.varlist | |
| else: | |
| vals = self.interp.most_confused() | |
| self._ranges = [] | |
| self.tbnames = [] | |
| self._boxes = int(input('Please enter a value for `k`, or the top images you will see: ')) | |
| for x in iter(vals): | |
| for y in range(len(self.lis)): | |
| if x[0:2] == self.lis[y]: | |
| self._ranges.append(x[2]) | |
| self.tbnames.append(str(x[0] + ' | ' + x[1])) | |
| items = [widgets.Output() for i, tab in enumerate(self.tbnames)] | |
| self.tabs = widgets.Tab() | |
| self.tabs.children = items | |
| for i in range(len(items)): | |
| self.tabs.set_title(i, self.tbnames[i]) | |
| self._populate_tabs() | |
| def _populate_tabs(self): | |
| "Adds relevant graphs to each tab" | |
| with tqdm(total=len(self.tbnames)) as pbar: | |
| for i, tab in enumerate(self.tbnames): | |
| with self.tabs.children[i]: | |
| self._plot_tab(tab) if self._is_tab else self._plot_imgs(tab, i) | |
| pbar.update(1) | |
| display(self.tabs) | |
| def _plot_tab(self, tab:str): | |
| "Generates graphs" | |
| if self._boxes is not None: | |
| fig, ax = plt.subplots(self._boxes, figsize=self.figsize) | |
| else: | |
| fig, ax = plt.subplots(self._cols, self._rows, figsize=self.figsize) | |
| fig.subplots_adjust(hspace=.5) | |
| for j, x in enumerate(self.df_list): | |
| title = f'{"".join(x.columns[-1])} {tab} distribution' | |
| if self._boxes is None: | |
| row = int(j / self._cols) | |
| col = j % row | |
| if tab in self.cat_names: | |
| vals = pd.value_counts(x[tab].values) | |
| if self._boxes is not None: | |
| if vals.nunique() < 10: | |
| fig = vals.plot(kind='bar', title=title, ax=ax[j], rot=0, width=.75) | |
| elif vals.nunique() > self.cut_off: | |
| print(f'Number of values is above {self.cut_off}') | |
| else: | |
| fig = vals.plot(kind='barh', title=title, ax=ax[j], width=.75) | |
| else: | |
| fig = vals.plot(kind='barh', title=title, ax=ax[row, col], width=.75) | |
| else: | |
| vals = x[tab] | |
| if self._boxes is not None: | |
| axs = vals.plot(kind='hist', ax=ax[j], title=title, y='Frequency') | |
| else: | |
| axs = vals.plot(kind='hist', ax=ax[row, col], title=title, y='Frequency') | |
| axs.set_ylabel('Frequency') | |
| if len(set(vals)) > 1: | |
| vals.plot(kind='kde', ax=axs, title=title, secondary_y=True) | |
| else: | |
| print('Less than two unique values, cannot graph the KDE') | |
| plt.show(fig) | |
| plt.tight_layout() | |
| def _plot_imgs(self, tab:str, i:int ,**kwargs): | |
| "Plots the most confused images" | |
| classes_gnd = self.interp.data.classes | |
| x = 0 | |
| if self._ranges[i] < self._boxes: | |
| cols = math.ceil(math.sqrt(self._ranges[i])) | |
| rows = math.ceil(self._ranges[i]/cols) | |
| if self._ranges[i] < 4 or self._boxes < 4: | |
| cols = 2 | |
| rows = 2 | |
| else: | |
| cols = math.ceil(math.sqrt(self._boxes)) | |
| rows = math.ceil(self._boxes/cols) | |
| fig, ax = plt.subplots(rows, cols, figsize=self.figsize) | |
| [axi.set_axis_off() for axi in ax.ravel()] | |
| for j, idx in enumerate(self.tl_idx): | |
| if self._boxes < x+1 or x > self._ranges[i]: | |
| break | |
| da, cl = self.interp.data.dl(self.interp.ds_type).dataset[idx] | |
| row = (int)(x / cols) | |
| col = x % cols | |
| if str(cl) == tab.split(' ')[0] and str(classes_gnd[self.interp.pred_class[idx]]) == tab.split(' ')[2]: | |
| img, lbl = self.interp.data.valid_ds[idx] | |
| fn = self.interp.data.valid_ds.x.items[idx] | |
| fn = re.search('([^/*]+)_\d+.*$', str(fn)).group(0) | |
| img.show(ax=ax[row, col]) | |
| ax[row,col].set_title(fn) | |
| x += 1 | |
| plt.show(fig) | |
| plt.tight_layout() | |
| def _tab_losses(self, **kwargs): | |
| "Gathers dataframes of the combinations data" | |
| classes = self.interp.data.classes | |
| cat_names = self.interp.data.x.cat_names | |
| cont_names = self.interp.data.x.cont_names | |
| comb = self.classl if self.is_ordered else list(permutations(self.classl,2)) | |
| self.df_list = [] | |
| arr = [] | |
| for i, idx in enumerate(self.tl_idx): | |
| da, _ = self.interp.data.dl(self.interp.ds_type).dataset[idx] | |
| res = '' | |
| for c, n in zip(da.cats, da.names[:len(da.cats)]): | |
| string = f'{da.classes[n][c]}' | |
| if string == 'True' or string == 'False': | |
| string += ';' | |
| res += string | |
| else: | |
| string = string[1:] | |
| res += string + ';' | |
| for c, n in zip(da.conts, da.names[len(da.cats):]): | |
| res += f'{c:.4f};' | |
| arr.append(res) | |
| f = pd.DataFrame([ x.split(';')[:-1] for x in arr], columns=da.names) | |
| for i, var in enumerate(self.interp.data.cont_names): | |
| f[var] = f[var].apply(lambda x: float(x) * self.stds[var] + self.means[var]) | |
| f['Original'] = 'Original' | |
| self.df_list.append(f) | |
| for j, x in enumerate(comb): | |
| arr = [] | |
| for i, idx in enumerate(self.tl_idx): | |
| da, cl = self.interp.data.dl(self.interp.ds_type).dataset[idx] | |
| cl = int(cl) | |
| if classes[self.interp.pred_class[idx]] == comb[j][0] and classes[cl] == comb[j][1]: | |
| res = '' | |
| for c, n in zip(da.cats, da.names[:len(da.cats)]): | |
| string = f'{da.classes[n][c]}' | |
| if string == 'True' or string == 'False': | |
| string += ';' | |
| res += string | |
| else: | |
| string = string[1:] | |
| res += string + ';' | |
| for c, n in zip(da.conts, da.names[len(da.cats):]): | |
| res += f'{c:.4f};' | |
| arr.append(res) | |
| f = pd.DataFrame([ x.split(';')[:-1] for x in arr], columns=da.names) | |
| for i, var in enumerate(self.interp.data.cont_names): | |
| f[var] = f[var].apply(lambda x: float(x) * self.stds[var] + self.means[var]) | |
| f[str(x)] = str(x) | |
| self.df_list.append(f) | |
| self.cat_names = cat_names | |
| self._create_tabs() | |