File size: 10,884 Bytes
e9f9fd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
from ..torch_core import *
from ..basic_train import *
from ..basic_data import *
from ..vision.data import *
from ..vision.transform import *
from ..vision.image import *
from ..callbacks.hooks import *
from ..layers import *
from ipywidgets import widgets, Layout
from IPython.display import clear_output, display

# Public names exported by `from ... import *`: the formatter that builds (dataset, indices) inputs
# and the Jupyter widget that displays them for relabeling/deletion.
__all__ = ['DatasetFormatter', 'ImageCleaner']

class DatasetFormatter():
    "Returns a dataset with the appropriate format and file indices to be displayed."
    @classmethod
    def from_toplosses(cls, learn, n_imgs=None, **kwargs):
        "Gets indices with top losses; returns `(padded dataset, indices)`. `kwargs` go to `padded_ds`."
        train_ds, train_idxs = cls.get_toplosses_idxs(learn, n_imgs, **kwargs)
        return train_ds, train_idxs

    @classmethod
    def get_toplosses_idxs(cls, learn, n_imgs, **kwargs):
        """Sort the `Fix` dataset by loss (highest first) and return the padded dataset and sorted indices.

        `n_imgs` caps how many indices are returned; a falsy value, or one larger than the dataset,
        returns all of them. Remaining `kwargs` are forwarded to `padded_ds`."""
        dl = learn.data.fix_dl
        _,_,top_losses = learn.get_preds(ds_type=DatasetType.Fix, with_loss=True)
        # torch.topk raises when k exceeds the number of elements, so clamp the request.
        n_imgs = len(top_losses) if not n_imgs else min(n_imgs, len(top_losses))
        idxs = torch.topk(top_losses, n_imgs)[1]
        return cls.padded_ds(dl.dataset, **kwargs), idxs

    @staticmethod
    def padded_ds(ll_input, size=(250, 300), resize_method=ResizeMethod.CROP, padding_mode='zeros', **kwargs):
        "For a LabelList `ll_input`, resize each image to `size` using `resize_method` and `padding_mode`."
        # Extra `kwargs` (e.g. pooling options meant for other steps) are accepted and ignored.
        return ll_input.transform(tfms=crop_pad(), size=size, resize_method=resize_method, padding_mode=padding_mode)

    @classmethod
    def from_similars(cls, learn, layer_ls=(0, 7, 2), **kwargs):
        """Gets the indices for the most similar images; returns `(padded dataset, indices)`.

        `layer_ls` holds three indices selecting `learn.model[a][b][c]`, the layer whose activations
        are compared (a tuple default avoids a shared mutable list)."""
        train_ds, train_idxs = cls.get_similars_idxs(learn, layer_ls, **kwargs)
        return train_ds, train_idxs

    @classmethod
    def get_similars_idxs(cls, learn, layer_ls, **kwargs):
        "Gets the indices for the most similar images in the `Fix` dataset, flattened as consecutive pairs."
        hook = hook_output(learn.model[layer_ls[0]][layer_ls[1]][layer_ls[2]])
        dl = learn.data.fix_dl
        try:
            ds_actns = cls.get_actns(learn, hook=hook, dl=dl, **kwargs)
        finally:
            # Detach the forward hook; otherwise it keeps firing on every later forward pass.
            hook.remove()
        similarities = cls.comb_similarity(ds_actns, ds_actns, **kwargs)
        idxs = cls.sort_idxs(similarities)
        # Pass the LabelList itself, exactly as `get_toplosses_idxs` does.
        return cls.padded_ds(dl.dataset, **kwargs), idxs

    @staticmethod
    def get_actns(learn, hook:Hook, dl:DataLoader, pool=AdaptiveConcatPool2d, pool_dim:int=4, **kwargs):
        """Gets activations at the layer specified by `hook`, applies `pool` of dim `pool_dim` and concatenates.

        Returns a 2-D tensor with one flattened row per item of `dl.x`. Leaves `learn.model` in eval mode."""
        print('Getting activations...')

        actns = []
        learn.model.eval()
        with torch.no_grad():
            for (xb,yb) in progress_bar(dl):
                learn.model(xb)
                actns.append((hook.stored).cpu())

        actns = torch.cat(actns)
        if pool: actns = pool(pool_dim)(actns)
        # assumes one stored activation row per dataset item, in order — TODO confirm for custom dls
        return actns.view(len(dl.x),-1)

    @staticmethod
    def comb_similarity(t1: torch.Tensor, t2: torch.Tensor, **kwargs):
        """Computes the cosine similarity between each row of `t1` and each row of `t2`.

        Only the strictly-lower triangle is kept (diagonal and above are zeroed), so each unordered
        pair of distinct rows appears once when `t1 is t2`."""
        # https://github.com/pytorch/pytorch/issues/11202
        print('Computing similarities...')

        w1 = t1.norm(p=2, dim=1, keepdim=True)
        w2 = w1 if t2 is t1 else t2.norm(p=2, dim=1, keepdim=True)

        t = torch.mm(t1, t2.t()) / (w1 * w2.t()).clamp(min=1e-8)
        return torch.tril(t, diagonal=-1)

    @staticmethod
    def largest_indices(arr, n):
        """Returns the multi-dimensional indices of the `n` largest values of `arr`, largest first.

        `arr` may be a numpy array or a CPU tensor; `n` is clamped to `[0, arr.size]`."""
        #https://stackoverflow.com/questions/6910641/how-do-i-get-indices-of-n-maximum-values-in-a-numpy-array
        arr = np.asarray(arr)
        flat = arr.flatten()
        n = min(n, flat.size)
        # argpartition(..., -0)[-0:] would return every index, so handle n <= 0 explicitly.
        if n <= 0: return tuple(np.array([], dtype=np.intp) for _ in arr.shape)
        indices = np.argpartition(flat, -n)[-n:]
        indices = indices[np.argsort(-flat[indices])]
        return np.unravel_index(indices, arr.shape)

    @classmethod
    def sort_idxs(cls, similarities):
        """Sorts `similarities` and return the indexes in pairs ordered by highest similarity.

        Expects the strictly-lower-triangular matrix from `comb_similarity`; the result is flat:
        `[a0, b0, a1, b1, ...]` for pairs `(a_k, b_k)`."""
        n = len(similarities)
        # Only n*(n-1)/2 cells hold real pairs; never ask for more, or zeroed cells become fake pairs.
        rows, cols = cls.largest_indices(similarities, min(n, n * (n - 1) // 2))
        return [i for pair in zip(rows, cols) for i in pair]

class ImageCleaner():
    "Displays images for relabeling or deletion and saves changes in `path` as 'cleaned.csv'."
    def __init__(self, dataset, fns_idxs, path, batch_size:int=5, duplicates=False):
        self._batch = []
        self._path = Path(path)
        # In duplicates mode every batch is exactly one pair of images.
        self._batch_size = 2 if duplicates else batch_size
        self._duplicates = duplicates
        self._labels = dataset.classes
        self._all_images = self.create_image_list(dataset, fns_idxs)
        # file (keyed exactly as in `dataset.x.items`) -> current label; this is what `write_csv` saves.
        self._csv_dict = {dataset.x.items[i]: dataset.y[i] for i in range(len(dataset))}
        self._deleted_fns = []
        self._skipped = 0
        self.render()

    @classmethod
    def make_img_widget(cls, img, layout=None, format='jpg'):
        "Returns an image widget showing image data `img` (a fresh `Layout()` is used when `layout` is None)."
        return widgets.Image(value=img, format=format, layout=Layout() if layout is None else layout)

    @classmethod
    def make_button_widget(cls, label, file_path=None, handler=None, style=None, layout=None):
        "Return a Button widget with specified `handler`, tagged with `file_path` and an unset delete flag."
        if layout is None: layout = Layout(width='auto')
        btn = widgets.Button(description=label, layout=layout)
        if handler is not None: btn.on_click(handler)
        if style is not None: btn.button_style = style
        btn.file_path = file_path
        btn.flagged_for_delete = False
        return btn

    @classmethod
    def make_dropdown_widget(cls, description='Description', options=('Label 1', 'Label 2'), value='Label 1',
                             file_path=None, layout=None, handler=None):
        "Return a Dropdown widget whose `handler` is called on every value change."
        dd = widgets.Dropdown(description=description, options=options, value=value,
                              layout=Layout() if layout is None else layout)
        if file_path is not None: dd.file_path = file_path
        if handler is not None: dd.observe(handler, names=['value'])
        return dd

    @classmethod
    def make_horizontal_box(cls, children, layout=None):
        "Make a horizontal box with `children` and `layout`."
        return widgets.HBox(children, layout=Layout() if layout is None else layout)

    @classmethod
    def make_vertical_box(cls, children, layout=None, duplicates=False):
        "Make a vertical box with `children` and `layout`; in duplicates mode only the 1st and 3rd child are shown."
        if layout is None: layout = Layout()
        if not duplicates: return widgets.VBox(children, layout=layout)
        return widgets.VBox([children[0], children[2]], layout=layout)

    def create_image_list(self, dataset, fns_idxs):
        "Create a list of (image data, filename, label) tuples, skipping files that no longer exist."
        items = dataset.x.items
        if self._duplicates:
            # Keep a pair only if it is complete and both of its files still exist.
            idxs = [i for chunk in chunks(fns_idxs, 2)
                    if len(chunk) == 2 and all(Path(items[j]).is_file() for j in chunk)
                    for i in chunk]
        else:
            idxs = [i for i in fns_idxs if Path(items[i]).is_file()]
        return [(dataset.x[i]._repr_jpeg_(), items[i], self._labels[dataset.y[i].data]) for i in idxs]

    def relabel(self, change):
        "Dropdown handler: record `change.new` as the label of the file shown by the dropdown (no files are moved)."
        # Key with the exact item the dropdown was built with, the same key `_csv_dict` and `delete_image` use.
        self._csv_dict[change.owner.file_path] = change.new

    def next_batch(self, _):
        "Handler for 'Next Batch' button click. Delete all flagged images and renders next batch."
        for _img_widget, delete_btn, _fp in self._batch:
            if delete_btn.flagged_for_delete:
                self.delete_image(delete_btn.file_path)
                self._deleted_fns.append(delete_btn.file_path)
        self._all_images = self._all_images[self._batch_size:]
        self.empty_batch()
        self.render()

    def on_delete(self, btn):
        "Flag this image as delete or keep."
        btn.button_style = "" if btn.flagged_for_delete else "danger"
        btn.flagged_for_delete = not btn.flagged_for_delete

    def empty_batch(self):
        "Forget the widgets of the current batch."
        self._batch.clear()

    def delete_image(self, file_path):
        "Drop `file_path` from the saved mapping; deleting an already-removed file is a no-op."
        self._csv_dict.pop(file_path, None)

    def empty(self):
        "True when there are no images left to display."
        return len(self._all_images) == 0

    def get_widgets(self, duplicates):
        "Create and format the widget set for the current batch, remembering each image's delete button."
        widget_list = []  # not `widgets`, which would shadow the ipywidgets module
        for (img,fp,human_readable_label) in self._all_images[:self._batch_size]:
            img_widget = self.make_img_widget(img, layout=Layout(height='250px', width='300px'))
            dropdown = self.make_dropdown_widget(description='', options=self._labels, value=human_readable_label,
                                                 file_path=fp, handler=self.relabel, layout=Layout(width='auto'))
            delete_btn = self.make_button_widget('Delete', file_path=fp, handler=self.on_delete)
            widget_list.append(self.make_vertical_box([img_widget, dropdown, delete_btn],
                                                      layout=Layout(width='auto', height='300px',
                                                                    overflow_x="hidden"), duplicates=duplicates))
            self._batch.append((img_widget, delete_btn, fp))
        return widget_list

    def batch_contains_deleted(self):
        "Check if current batch contains already deleted images (only checked in duplicates mode)."
        if not self._duplicates: return False
        return any(fp in self._deleted_fns for _, fp, _ in self._all_images[:self._batch_size])

    def _skip_deleted_batches(self):
        "Drop leading batches that show an already-deleted image, counting each one in `_skipped`."
        # Iterative on purpose: skipping via next_batch -> render recursion can exhaust the stack
        # and reported the skip count before incrementing it.
        while not self.empty() and self.batch_contains_deleted():
            self._all_images = self._all_images[self._batch_size:]
            self._skipped += 1

    def write_csv(self):
        "Write the current name/label mapping to `path`/'cleaned.csv' (names relative to `path`); return that path."
        csv_path = self._path/'cleaned.csv'
        # newline='' lets the csv module control line endings (avoids blank rows on Windows).
        with open(csv_path, 'w', newline='', encoding='utf-8') as f:
            csv_writer = csv.writer(f)
            csv_writer.writerow(['name','label'])
            csv_writer.writerows([os.path.relpath(fn, self._path), label] for fn, label in self._csv_dict.items())
        return csv_path

    def render(self):
        "Re-render Jupyter cell for batch of images."
        clear_output()
        self.write_csv()
        self._skip_deleted_batches()
        if self.empty() and self._skipped>0:
            return display(f'No images to show :). {self._skipped} pairs were '
                    f'skipped since at least one of the images was deleted by the user.')
        elif self.empty():
            return display('No images to show :)')
        display(self.make_horizontal_box(self.get_widgets(self._duplicates)))
        display(self.make_button_widget('Next Batch', handler=self.next_batch, style="primary"))