File size: 9,089 Bytes
0e83290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
"""

Contains all the types of interaction related to the GUI

Not related to automatic evaluation in the DAVIS dataset



You can inherit the Interaction class to create new interaction types

undo is (sometimes partially) supported

"""


import torch
import torch.nn.functional as F
import numpy as np
import cv2
import time
from .interactive_utils import color_map, index_numpy_to_one_hot_torch


def aggregate_sbg(prob, keep_bg=False, hard=False):
    """
    Soft-aggregate stacked probability maps against a fixed 0.5 background.

    prob    - tensor of shape (k, h, w); assumed to hold probabilities in [0, 1]
    keep_bg - if True, the background channel (index 0) is kept in the output
    hard    - if True, logits are scaled by 1000 (very low temperature), so the
              softmax becomes nearly one-hot

    Returns a (k+1, h, w) tensor when keep_bg is set, otherwise (k, h, w).
    """
    num_objects, height, width = prob.shape
    stacked = torch.zeros((num_objects + 1, height, width), device=prob.device)
    # A constant 0.5 background corresponds to a logit of exactly 0.
    stacked[0] = 0.5
    stacked[1:] = prob
    # Clamp away from 0 and 1 so the log-odds below stay finite.
    stacked = stacked.clamp(1e-7, 1 - 1e-7)
    logits = torch.log(stacked / (1 - stacked))

    if hard:
        # Very low temperature o((⊙﹏⊙))o 🥶
        logits = logits * 1000

    normalized = F.softmax(logits, dim=0)
    return normalized if keep_bg else normalized[1:]

def aggregate_wbg(prob, keep_bg=False, hard=False):
    """
    Soft-aggregate stacked probability maps with a derived background channel.

    The background probability is the product of (1 - p) over all k maps.

    prob    - tensor of shape (k, h, w); assumed to hold probabilities in [0, 1]
    keep_bg - if True, the background channel (index 0) is kept in the output
    hard    - if True, logits are scaled by 1000 (very low temperature), so the
              softmax becomes nearly one-hot

    Returns a (k+1, h, w) tensor when keep_bg is set, otherwise (k, h, w).
    """
    # Unpacking enforces a 3-D input, exactly like the original implementation.
    _num_objects, _height, _width = prob.shape

    background = (1 - prob).prod(dim=0, keepdim=True)
    # Clamp away from 0 and 1 so the log-odds below stay finite.
    stacked = torch.cat((background, prob), dim=0).clamp(1e-7, 1 - 1e-7)
    logits = (stacked / (1 - stacked)).log()

    if hard:
        # Very low temperature o((⊙﹏⊙))o 🥶
        logits = logits * 1000

    normalized = F.softmax(logits, dim=0)
    return normalized if keep_bg else normalized[1:]

class Interaction:
    """
    Base class for one user interaction in the GUI.

    Subclasses record their user input and override predict(), which in the
    subclasses of this module sets and returns self.out_prob.
    """
    def __init__(self, image, prev_mask, true_size, controller):
        self.image = image
        self.prev_mask = prev_mask
        self.controller = controller
        # Wall-clock time (seconds since the epoch) when the interaction began.
        self.start_time = time.time()

        # true_size is unpacked as (height, width).
        height, width = true_size
        self.h, self.w = height, width

        # Results; filled in by subclasses.
        self.out_prob = None
        self.out_mask = None

    def predict(self):
        """Compute the interaction result; the base class does nothing."""
        return None


class FreeInteraction(Interaction):
    """
    Free-hand drawing that paints ids directly into a copy of the previous mask.

    No controller is used (it is passed as None); predict() only converts the
    painted index map into a K+1 channel tensor.
    """
    def __init__(self, image, prev_mask, true_size, num_objects):
        """
        prev_mask should be index format numpy array (one id per pixel).
        num_objects - number of objects K; strokes exist for ids 0..K.
        """
        super().__init__(image, prev_mask, true_size, None)

        self.K = num_objects

        # Painting starts from (a copy of) the previous mask.
        self.drawn_map = self.prev_mask.copy()
        # One in-progress stroke per id (K + 1 of them).
        self.curr_path = [[] for _ in range(self.K + 1)]

        # Brush thickness in pixels; must be set with set_size() before a
        # stroke reaches two points, since it is passed straight to cv2.line.
        self.size = None

    def set_size(self, size):
        """Set the brush thickness (in pixels) used by push_point."""
        self.size = size

    @staticmethod
    def _to_pixel(point):
        """Round an (x, y) point to integer pixel coordinates for cv2."""
        return int(round(point[0])), int(round(point[1]))

    def push_point(self, x, y, k, vis=None):
        """
        Extend the current stroke of id k to (x, y).

        k   - object id (index into curr_path and color_map)
        vis - a tuple (visualization map, pass through alpha). None if not needed.

        Once the stroke has at least two points, the segment between its last
        two points is painted with value k into drawn_map and, when vis is
        given, with color_map[k] into the visualization map and 0.75 into the
        alpha map. Returns (vis_map, vis_alpha) when vis is given, else None.
        """
        if vis is not None:
            vis_map, vis_alpha = vis

        stroke = self.curr_path[k]
        stroke.append((x, y))

        if len(stroke) >= 2:
            start = self._to_pixel(stroke[-2])
            end = self._to_pixel(stroke[-1])
            # Keep the array cv2 returns, as ScribbleInteraction does.
            self.drawn_map = cv2.line(self.drawn_map, start, end, k, thickness=self.size)

            if vis is not None:
                # Background and object strokes are drawn the same way; only
                # the color looked up in color_map differs.
                vis_map = cv2.line(vis_map, start, end, color_map[k], thickness=self.size)
                # Visualization on/off filter
                vis_alpha = cv2.line(vis_alpha, start, end, 0.75, thickness=self.size)

        if vis is not None:
            return vis_map, vis_alpha

    def end_path(self):
        """Finish the current strokes; the next point starts a new stroke."""
        self.curr_path = [[] for _ in range(self.K + 1)]

    def predict(self):
        """
        Convert drawn_map with index_numpy_to_one_hot_torch into K+1 channels
        (presumably one-hot), move it to the GPU with .cuda(), store it in
        self.out_prob and return it.
        """
        self.out_prob = index_numpy_to_one_hot_torch(self.drawn_map, self.K+1).cuda()
        return self.out_prob

class ScribbleInteraction(Interaction):
    """
    Scribble interaction: strokes are rasterized into a separate scribble map,
    which is handed to the controller together with the image and prev_mask.
    """
    def __init__(self, image, prev_mask, true_size, controller, num_objects):
        """
        prev_mask should be in an indexed form.
        num_objects - number of objects K; strokes exist for ids 0..K.
        """
        super().__init__(image, prev_mask, true_size, controller)

        self.K = num_objects

        # Every pixel starts at 255; strokes overwrite it with their id k.
        self.drawn_map = np.full((self.h, self.w), 255, dtype=np.uint8)
        # background + k
        self.curr_path = [[] for _ in range(self.K + 1)]
        # Stroke thickness in pixels.
        self.size = 3

    @staticmethod
    def _to_pixel(point):
        """Round an (x, y) point to integer pixel coordinates for cv2."""
        return int(round(point[0])), int(round(point[1]))

    def push_point(self, x, y, k, vis=None):
        """
        Extend the current stroke of id k to (x, y).

        k   - object id (index into curr_path and color_map)
        vis - a tuple (visualization map, pass through alpha). None if not needed.

        Once the stroke has at least two points, the segment between its last
        two points is painted with value k into drawn_map and, when vis is
        given, with color_map[k] into the visualization map and 0.75 into the
        alpha map. Returns (vis_map, vis_alpha) when vis is given, else None.
        """
        if vis is not None:
            vis_map, vis_alpha = vis

        stroke = self.curr_path[k]
        stroke.append((x, y))

        if len(stroke) >= 2:
            start = self._to_pixel(stroke[-2])
            end = self._to_pixel(stroke[-1])
            self.drawn_map = cv2.line(self.drawn_map, start, end, k, thickness=self.size)

            if vis is not None:
                # Background and object strokes are drawn the same way; only
                # the color looked up in color_map differs.
                vis_map = cv2.line(vis_map, start, end, color_map[k], thickness=self.size)
                # Visualization on/off filter
                vis_alpha = cv2.line(vis_alpha, start, end, 0.75, thickness=self.size)

        # Optional vis return
        if vis is not None:
            return vis_map, vis_alpha

    def end_path(self):
        """Finish the current strokes; the next point starts a new stroke."""
        self.curr_path = [[] for _ in range(self.K + 1)]

    def predict(self):
        """
        Run the controller on (image batch of one, prev_mask, drawn_map), then
        hard-aggregate the output with aggregate_wbg, keeping the background
        channel. The result is stored in self.out_prob and returned.
        """
        self.out_prob = self.controller.interact(self.image.unsqueeze(0), self.prev_mask, self.drawn_map)
        self.out_prob = aggregate_wbg(self.out_prob, keep_bg=True, hard=True)
        return self.out_prob


class ClickInteraction(Interaction):
    """
    Click interaction for a single target object.

    Each click is forwarded to the controller immediately (push_point);
    predict() then merges the latest object mask into the previous
    probabilities.
    """
    def __init__(self, image, prev_mask, true_size, controller, tar_obj):
        """
        prev_mask in a prob. form (a tensor; it is cloned and indexed by
        channel). tar_obj is the channel of prev_mask that clicks refine;
        channel 0 is dropped before aggregation in predict().
        """
        super().__init__(image, prev_mask, true_size, controller)
        self.tar_obj = tar_obj

        # Click coordinates for the target object, split by polarity.
        self.pos_clicks, self.neg_clicks = [], []

        self.out_prob = self.prev_mask.clone()

    def push_point(self, x, y, neg, vis=None):
        """
        Register a click at (x, y) and refresh self.obj_mask via the controller.

        neg - Negative interaction or not
        vis - a tuple (visualization map, pass through alpha). None if not needed.

        Returns (vis_map, vis_alpha) with the click drawn when vis is given,
        otherwise None.
        """
        (self.neg_clicks if neg else self.pos_clicks).append((x, y))

        # The controller receives the polarity as "is positive".
        self.obj_mask = self.controller.interact(self.image.unsqueeze(0), x, y, not neg)

        if vis is None:
            return None

        vis_map, vis_alpha = vis
        center = (int(round(x)), int(round(y)))
        # Negative clicks use color_map[0]; positive ones the target's color.
        color = color_map[0] if neg else color_map[self.tar_obj]
        vis_map = cv2.circle(vis_map, center, 2, color, thickness=-1)
        vis_alpha = cv2.circle(vis_alpha, center, 2, 1, thickness=-1)
        # Optional vis return
        return vis_map, vis_alpha

    def predict(self):
        """
        Overwrite the target channel of prev_mask with the latest obj_mask and
        hard-aggregate channels 1: with aggregate_wbg (background kept).

        Requires at least one push_point() call, which is what sets obj_mask.
        """
        # a small hack to allow the interacting object to overwrite existing masks
        # without remembering all the object probabilities
        self.out_prob = torch.clamp(self.prev_mask.clone(), max=0.9)
        self.out_prob[self.tar_obj] = self.obj_mask
        self.out_prob = aggregate_wbg(self.out_prob[1:], keep_bg=True, hard=True)
        return self.out_prob