File size: 13,403 Bytes
14114e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

import cv2
import numpy as np
import torch
from PIL import Image as PILImage
from pycocotools import mask as mask_util

from sam3.train.data.sam3_image_dataset import Datapoint
from torchvision.ops import masks_to_boxes


def sample_points_from_rle(rle, n_points, mode, box=None, normalize=True):
    """
    Decode a COCO RLE mask and sample points from it.

    Args:
      rle: mask in COCO run-length encoding, decoded with pycocotools.
      n_points: number of points to sample.
      mode: one of "centered", "random_mask", "random_box":
        "centered": points are sampled farthest from the mask edges and each other
        "random_mask": points are sampled uniformly from the mask
        "random_box": points are sampled uniformly from the annotation's box
      box: unnormalized XYXY box; must be provided when mode is "random_box".
      normalize: if True, x is divided by the mask width and y by the mask
        height so coordinates lie in [0, 1]; the label column is unchanged.

    Returns:
      Array of [x, y, label] rows, one per sampled point.
    """
    decoded = np.ascontiguousarray(mask_util.decode(rle))
    sampled = sample_points_from_mask(decoded, n_points, mode, box)
    if not normalize:
        return sampled

    height, width = decoded.shape
    # Divide x by width, y by height; the trailing 1.0 leaves labels untouched.
    scale = np.array([width, height, 1.0])[None, :]
    return sampled / scale


def sample_points_from_mask(mask, n_points, mode, box=None):
    """
    Dispatch point sampling on a binary mask to the sampler named by `mode`.

    Args:
      mask: 2D array whose nonzero pixels are foreground.
      n_points: number of points to sample.
      mode: "centered", "random_mask" or "random_box".
      box: unnormalized XYXY box; required only for "random_box".

    Returns:
      Array of [x, y, label] rows produced by the selected sampler.

    Raises:
      ValueError: if `mode` is not a known sampling mode.
      AssertionError: if mode is "random_box" and no box is given.
    """
    if mode == "centered":
        return center_positive_sample(mask, n_points)
    if mode == "random_mask":
        return uniform_positive_sample(mask, n_points)
    if mode != "random_box":
        raise ValueError(f"Unknown point sampling mode {mode}.")
    assert box is not None, "'random_box' mode requires a provided box."
    return uniform_sample_from_box(mask, box, n_points)


def uniform_positive_sample(mask, n_points):
    """
    Draw `n_points` foreground pixels uniformly, with replacement, from `mask`.

    Only integer pixel coordinates are produced. Every returned point is
    labelled positive (1.0).

    Returns:
      Float array of shape (n_points, 3) with rows [x, y, 1.0].

    Raises:
      AssertionError: if the mask has no nonzero pixel.
    """
    # (N, 2) array of (row, col) == (y, x) indices of foreground pixels.
    # Sampling directly from the uncompressed RLE would be faster but is
    # likely unnecessary.
    coords = np.argwhere(mask)
    assert len(coords) > 0, "Can't sample positive points from an empty mask."

    picks = coords[np.random.randint(low=0, high=len(coords), size=n_points)]
    xy = np.flip(picks, axis=1)  # (y, x) -> (x, y)
    positive = np.ones((xy.shape[0], 1))
    return np.concatenate([xy, positive], axis=1)


def center_positive_sample(mask, n_points):
    """
    Sample points farthest from the mask edges, and from each other.

    The first point is the foreground pixel with the largest L2 distance
    transform value. Each chosen pixel is then marked as background, so later
    points are pushed away from earlier ones. The image border is treated as a
    mask edge.

    If more points are requested than there are foreground pixels, sampling
    restarts from the full mask (i.e. points repeat). Without the restart,
    argmax over an all-zero distance map would yield an out-of-image (-1, -1)
    point.

    Args:
      mask: 2D array; nonzero pixels are foreground (any dtype, incl. bool).
      n_points: number of points to sample; 0 returns an empty (0, 3) array.

    Returns:
      Float array of shape (n_points, 3) with rows [x, y, 1.0] in pixel
      coordinates.

    Raises:
      AssertionError: if the mask has no foreground pixel.
    """
    # cv2.distanceTransform only accepts 8-bit single-channel input, so
    # binarize to uint8 (a bool mask, e.g. from a torch tensor, is rejected).
    # Pad by one background pixel on each side so image borders act as edges.
    padded_mask = np.pad((np.asarray(mask) != 0).astype(np.uint8), 1)
    assert padded_mask.any(), "Can't sample positive points from an empty mask."
    if n_points <= 0:
        return np.zeros((0, 3))

    remaining = padded_mask.copy()
    points = []
    for _ in range(n_points):
        if not remaining.any():
            # Every foreground pixel has been used; start over.
            remaining = padded_mask.copy()
        dist = cv2.distanceTransform(remaining, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
        y, x = np.unravel_index(dist.argmax(), dist.shape)
        # Mark selected point as background so the next point avoids it.
        remaining[y, x] = 0
        points.append((x, y))

    points = np.asarray(points) - 1  # undo the 1-pixel left/top padding
    labels = np.ones((len(points), 1))
    return np.concatenate([points, labels], axis=1)


def uniform_sample_from_box(mask, box, n_points):
    """
    Sample integer points uniformly from an unnormalized XYXY box.

    Each point's label is the mask value at that pixel, so no positive point
    is guaranteed. The box is clipped to the mask first:
      - a negative coordinate would otherwise wrap around under numpy indexing
        and read the wrong pixel;
      - a coordinate past the edge would raise IndexError.
    A box too thin to contain any integer coordinate samples the nearest valid
    column/row instead of making np.random.randint fail on an empty range.

    Args:
      mask: 2D array indexed as mask[y, x].
      box: [x0, y0, x1, y1] in pixels; right/bottom edges are exclusive.
      n_points: number of points to sample.

    Returns:
      Integer array of shape (n_points, 3) with rows [x, y, mask[y, x]].
    """
    height, width = mask.shape[0], mask.shape[1]

    # An integer c lies in [lo, hi) iff ceil(lo) <= c < ceil(hi), so ceil
    # is applied to every edge.
    x0, y0, x1, y1 = (
        int(v) for v in np.ceil(np.asarray(box, dtype=np.float64)[:4])
    )
    x0 = min(max(x0, 0), width - 1)
    y0 = min(max(y0, 0), height - 1)
    # Keep at least one candidate column/row so the exclusive range is non-empty.
    x1 = min(max(x1, x0 + 1), width)
    y1 = min(max(y1, y0 + 1), height)

    x = np.random.randint(low=x0, high=x1, size=n_points)
    y = np.random.randint(low=y0, high=y1, size=n_points)
    labels = mask[y, x]
    points = np.stack([x, y, labels], axis=1)

    return points


def rescale_box_xyxy(box, factor, imsize=None):
    """
    Scale an unnormalized XYXY box by `factor` about its center.

    Args:
      box: indexable [x0, y0, x1, y1].
      factor: multiplier applied to both the width and the height.
      imsize: optional (height, width); when given, every coordinate is
        clamped to [0, width] (x) or [0, height] (y).

    Returns:
      List [x0, y0, x1, y1] of the rescaled box.
    """
    center_x = (box[0] + box[2]) / 2
    center_y = (box[1] + box[3]) / 2
    half_w = factor * (box[2] - box[0]) / 2
    half_h = factor * (box[3] - box[1]) / 2

    corners = [
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    ]
    if imsize is None:
        return corners

    height, width = imsize[0], imsize[1]
    limits = (width, height, width, height)
    return [max(min(value, limit), 0) for value, limit in zip(corners, limits)]


def noise_box(box, im_size, box_noise_std, box_noise_max, min_box_area):
    """
    Jitter an unnormalized XYXY box with Gaussian noise.

    Each coordinate gets noise drawn from N(0, box_noise_std**2), scaled by
    the box width (x coordinates) or height (y coordinates).

    Args:
      box: tensor [x0, y0, x1, y1] in pixels.
      im_size: (height, width) used to clamp the noised box to the image.
      box_noise_std: relative noise std; <= 0 disables noising.
      box_noise_max: if not None, absolute cap (pixels) on each coordinate's
        noise.
      min_box_area: the noised box is rejected if its area is <= this.

    Returns:
      The noised, image-clamped box. Returns the original `box` unchanged when
      noising is disabled, or when the noised box has area <= min_box_area
      (including inverted or collapsed boxes).
    """
    if box_noise_std <= 0.0:
        return box
    noise = box_noise_std * torch.randn(size=(4,))
    w, h = box[2] - box[0], box[3] - box[1]
    scale_factor = torch.tensor([w, h, w, h])
    noise = noise * scale_factor
    if box_noise_max is not None:
        noise = torch.clamp(noise, -box_noise_max, box_noise_max)
    input_box = box + noise
    # Clamp to the image extent.
    img_clamp = torch.tensor([im_size[1], im_size[0], im_size[1], im_size[0]])
    input_box = torch.maximum(input_box, torch.zeros_like(input_box))
    input_box = torch.minimum(input_box, img_clamp)
    # Clamp side lengths at 0 before computing the area. Noise can push x0
    # past x1 AND y0 past y1; the product of two negative sides is positive
    # and would otherwise let an inverted box through.
    new_w = (input_box[2] - input_box[0]).clamp(min=0)
    new_h = (input_box[3] - input_box[1]).clamp(min=0)
    if new_w * new_h <= min_box_area:
        return box

    return input_box


class RandomGeometricInputsAPI:
    """
    For geometric queries, replace the input box and/or points with random
    ones sampled from the GT mask.

    Segments must be provided for objects that are targets of geometric
    queries, and must be binary masks. Existing point and box queries in the
    datapoint are ignored and completely replaced, unless `concat_points` is
    set. Points and boxes are produced in XYXY format in absolute pixel space.

    Geometry queries are currently identified as any query whose query text
    equals `geometric_query_str`.

    Args:
      num_points (int or (int, int)): how many points to sample. If a pair,
        sample a random number of points uniformly over the inclusive range.
        The pair may be a tuple or a list; it is not modified.
      box_chance (float): fraction of time a box is sampled. A box replaces
        one sampled point.
      box_noise_std (float): if greater than 0, add noise to the sampled boxes
        with this std. Noise is relative to the length of the box side.
      box_noise_max (int): if not None, truncate any box noise larger than this
        in terms of absolute pixels.
      minimum_box_area (float): sampled boxes with area this size or smaller
        after noising use the original box instead. It is the input's
        responsibility to avoid original boxes that violate necessary area
        bounds.
      resample_box_from_mask (bool): if True, any sampled box is determined by
        the extrema of the provided mask. If False, the bbox stored on the
        target object is used.
      point_sample_mode (str): one of "centered", "random_mask", "random_box",
        controlling how points are sampled:
          "centered": points are sampled farthest from the mask edges and each other
          "random_mask": points are sampled uniformly from the mask
          "random_box": points are sampled uniformly from the annotation's box
        Note that "centered" may be too slow for on-line generation.
      sample_box_scale_factor (float): if not 1.0, the box handed to the point
        sampler is rescaled by this factor about its center and clamped to the
        mask size. It is only consumed by the "random_box" mode.
      geometric_query_str (str): value of query_text that marks a geometry
        query.
      concat_points (bool): if True, sampled points are appended to existing
        ones instead of replacing them, and an existing input box is kept when
        no box is sampled.
    """

    def __init__(
        self,
        num_points,
        box_chance,
        box_noise_std=0.0,
        box_noise_max=None,
        minimum_box_area=0.0,
        resample_box_from_mask=False,
        point_sample_mode="random_mask",
        sample_box_scale_factor=1.0,
        geometric_query_str="geometric",
        concat_points=False,
    ):
        if isinstance(num_points, (int, np.integer)):
            self.num_points = num_points
        else:
            # Build a fresh tuple instead of mutating the argument in place.
            # An in-place `+= 1` raises TypeError for tuples, and it widens a
            # caller-owned list every time that list is reused.
            low, high = num_points
            # Convert from the inclusive range to the exclusive upper bound
            # expected by torch.randint.
            self.num_points = (int(low), int(high) + 1)
        self.box_chance = box_chance
        self.box_noise_std = box_noise_std
        self.box_noise_max = box_noise_max
        self.minimum_box_area = minimum_box_area
        self.resample_box_from_mask = resample_box_from_mask
        self.point_sample_mode = point_sample_mode
        assert point_sample_mode in [
            "centered",
            "random_mask",
            "random_box",
        ], "Unknown point sample mode."
        self.geometric_query_str = geometric_query_str
        self.concat_points = concat_points
        self.sample_box_scale_factor = sample_box_scale_factor

    def _sample_num_points_and_if_box(self):
        """Return (number of points to sample, whether to also sample a box)."""
        if isinstance(self.num_points, tuple):
            n_points = torch.randint(
                low=self.num_points[0], high=self.num_points[1], size=(1,)
            ).item()
        else:
            n_points = self.num_points
        if self.box_chance > 0.0:
            use_box = torch.rand(size=(1,)).item() < self.box_chance
            n_points -= int(use_box)  # box stands in for one point
        else:
            use_box = False
        return n_points, use_box

    def _get_original_box(self, target_object):
        """Return the target's stored bbox, or one recomputed from its mask."""
        if not self.resample_box_from_mask:
            return target_object.bbox
        mask = target_object.segment
        # NOTE(review): masks_to_boxes reports the max foreground index as
        # x2/y2 (likely inclusive); confirm this matches the bbox convention.
        return masks_to_boxes(mask[None, :, :])[0]

    def _get_target_object(self, datapoint, query):
        """Return the single object a geometric query points at."""
        img = datapoint.images[query.image_id]
        targets = query.object_ids_output
        assert (
            len(targets) == 1
        ), "Geometric queries only support a single target object."
        target_idx = targets[0]
        return img.objects[target_idx]

    def __call__(self, datapoint, **kwargs):
        """Resample geometric inputs for every geometric query in place."""
        for query in datapoint.find_queries:
            if query.query_text != self.geometric_query_str:
                continue

            target_object = self._get_target_object(datapoint, query)
            n_points, use_box = self._sample_num_points_and_if_box()
            box = self._get_original_box(target_object)

            mask = target_object.segment
            if n_points > 0:
                # FIXME: The conversion to numpy and back to reuse code
                # is awkward, but this is all in the dataloader worker anyway
                # on CPU and so I don't think it should matter.
                # Assumes box and mask are CPU tensors (.numpy() is called).
                if self.sample_box_scale_factor != 1.0:
                    sample_box = rescale_box_xyxy(
                        box.numpy(), self.sample_box_scale_factor, mask.shape
                    )
                else:
                    sample_box = box.numpy()
                input_points = sample_points_from_mask(
                    mask.numpy(),
                    n_points,
                    self.point_sample_mode,
                    sample_box,
                )
                input_points = torch.as_tensor(input_points)
                input_points = input_points[None, :, :]
                if self.concat_points and query.input_points is not None:
                    input_points = torch.cat([query.input_points, input_points], dim=1)
            else:
                input_points = query.input_points if self.concat_points else None

            if use_box:
                # NOTE(review): assumes Image.size is (width, height); other
                # code derives (h, w) from Image.data instead — confirm.
                w, h = datapoint.images[query.image_id].size
                input_box = noise_box(
                    box,
                    (h, w),
                    box_noise_std=self.box_noise_std,
                    box_noise_max=self.box_noise_max,
                    min_box_area=self.minimum_box_area,
                )
                input_box = input_box[None, :]
            else:
                input_box = query.input_bbox if self.concat_points else None

            query.input_points = input_points
            query.input_bbox = input_box

        return datapoint


class RandomizeInputBbox:
    """
    Box-only geometric augmentation: jitter every existing input box of each
    find query with `noise_box`, writing the result back into the query's
    `input_bbox` tensor in place. Queries without an input box are skipped.

    Args:
      box_noise_std (float): relative noise std passed to noise_box; <= 0
        leaves boxes unchanged.
      box_noise_max (float or None): absolute per-coordinate noise cap.
      minimum_box_area (float): noised boxes with area <= this keep their
        original value.
    """

    def __init__(
        self,
        box_noise_std=0.0,
        box_noise_max=None,
        minimum_box_area=0.0,
    ):
        self.box_noise_std = box_noise_std
        self.box_noise_max = box_noise_max
        self.minimum_box_area = minimum_box_area

    @staticmethod
    def _image_hw(image_data):
        """Return (height, width) of a PIL image or a (..., H, W) tensor."""
        if isinstance(image_data, PILImage.Image):
            width, height = image_data.size
            return height, width
        assert isinstance(image_data, torch.Tensor)
        return tuple(image_data.shape[-2:])

    def __call__(self, datapoint: Datapoint, **kwargs):
        for query in datapoint.find_queries:
            boxes = query.input_bbox
            if boxes is None:
                continue

            im_hw = self._image_hw(datapoint.images[query.image_id].data)

            # One box per row; each row is replaced by its noised version.
            for row in range(boxes.shape[0]):
                jittered = noise_box(
                    boxes[row, :].view(4),
                    im_hw,
                    box_noise_std=self.box_noise_std,
                    box_noise_max=self.box_noise_max,
                    min_box_area=self.minimum_box_area,
                )
                boxes[row, :] = jittered.view(1, 4)

        return datapoint