File size: 9,631 Bytes
c6535db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

"""Utilities for masks manipulation"""

import numpy as np
import pycocotools.mask as maskUtils
import torch
from pycocotools import mask as mask_util


def instance_masks_to_semantic_masks(
    instance_masks: torch.Tensor, num_instances: torch.Tensor
) -> torch.Tensor:
    """Collapse per-instance masks into one semantic (union) mask per image.

    The instance masks of the whole batch arrive as a single concatenated
    tensor; ``num_instances`` tells how many consecutive masks belong to each
    image. For every image the output is the pixel-wise OR of its masks. An
    image with zero instances yields an all-zero (False) mask.

    Args:
        instance_masks (torch.Tensor): (N, H, W) tensor, where N is the total
            number of instances across the batch (the split below requires
            N == num_instances.sum()).
        num_instances (torch.Tensor): (B,) tensor with the number of instances
            of each image.

    Returns:
        torch.Tensor: (B, H, W) tensor holding the per-image union masks.
    """
    # Consecutive groups of masks, one group per image (possibly empty).
    per_image = instance_masks.split(num_instances.tolist())
    unions = [group.any(dim=0) for group in per_image]
    return torch.stack(unions)


def mask_intersection(masks1, masks2, block_size=16):
    """Pairwise intersection areas between two stacks of boolean masks.

    Computes ``result[a, b]`` = number of pixels set in both ``masks1[a]`` and
    ``masks2[b]``. Pairs are processed ``block_size x block_size`` at a time so
    the broadcast intermediate never holds all N1 x N2 x H x W pairs at once.

    Args:
        masks1 (torch.Tensor): (N1, H, W) tensor with dtype=torch.bool.
        masks2 (torch.Tensor): (N2, H, W) tensor with dtype=torch.bool.
        block_size (int): number of masks taken from each side per chunk.

    Returns:
        torch.Tensor: (N1, N2) long tensor on ``masks1``'s device.
    """

    assert masks1.shape[1:] == masks2.shape[1:]
    assert masks1.dtype == torch.bool and masks2.dtype == torch.bool

    n_left, n_right = masks1.shape[0], masks2.shape[0]
    result = torch.zeros(n_left, n_right, device=masks1.device, dtype=torch.long)
    for row in range(0, n_left, block_size):
        # (b1, 1, H, W): broadcast against every mask of the right-hand chunk
        left_chunk = masks1[row : row + block_size].unsqueeze(1)
        for col in range(0, n_right, block_size):
            # (1, b2, H, W)
            right_chunk = masks2[col : col + block_size].unsqueeze(0)
            overlap = (left_chunk & right_chunk).flatten(-2).sum(-1)
            result[row : row + block_size, col : col + block_size] = overlap
    return result


def mask_iom(masks1, masks2):
    """Intersection-over-minimum for every pair of boolean masks.

    Like IoU, except the denominator is the area of the smaller mask of the
    pair. A tiny epsilon keeps the division finite: when one mask is empty the
    intersection is 0 as well, so the result is 0 rather than NaN.

    Args:
        masks1 (torch.Tensor): (N1, H, W) tensor with dtype=torch.bool.
        masks2 (torch.Tensor): (N2, H, W) tensor with dtype=torch.bool.

    Returns:
        torch.Tensor: (N1, N2) floating-point tensor.
    """
    assert masks1.shape[1:] == masks2.shape[1:]
    assert masks1.dtype == torch.bool and masks2.dtype == torch.bool

    overlap = mask_intersection(masks1, masks2)
    # Areas as a column (N1, 1) and a row (1, N2) so they broadcast pairwise.
    areas_left = masks1.flatten(-2).sum(-1).unsqueeze(1)
    areas_right = masks2.flatten(-2).sum(-1).unsqueeze(0)
    smaller_area = torch.min(areas_left, areas_right)
    return overlap / (smaller_area + 1e-8)


def compute_boundary(seg):
    """
    Adapted from https://github.com/JonathonLuiten/TrackEval/blob/master/trackeval/metrics/j_and_f.py#L148
    Return a 1pix wide boundary of the given mask

    A pixel is marked when it differs from its right, lower or lower-right
    neighbour. The last row only compares with its right neighbour, the last
    column only with its lower neighbour, and the bottom-right corner is always
    cleared. Works on any leading batch dimensions; the last two dims are
    treated as (H, W). The output has the same shape and dtype as ``seg``.
    """
    assert seg.ndim >= 2

    # Shifted copies of the mask, zero-filled where the neighbour falls outside.
    right = torch.zeros_like(seg)
    right[..., :, :-1] = seg[..., :, 1:]
    below = torch.zeros_like(seg)
    below[..., :-1, :] = seg[..., 1:, :]
    diagonal = torch.zeros_like(seg)
    diagonal[..., :-1, :-1] = seg[..., 1:, 1:]

    boundary = (seg ^ right) | (seg ^ below) | (seg ^ diagonal)
    # Edge handling follows TrackEval: no "below" neighbour on the last row,
    # no "right" neighbour on the last column.
    boundary[..., -1, :] = seg[..., -1, :] ^ right[..., -1, :]
    boundary[..., :, -1] = seg[..., :, -1] ^ below[..., :, -1]
    boundary[..., -1, -1] = 0
    return boundary


def dilation(mask, kernel_size):
    """
    Binary dilation of a stack of masks with a square ``kernel_size x kernel_size`` kernel.

    On CUDA tensors the dilation is computed with two separable box convolutions
    (vertical then horizontal) followed by a ``> 0`` threshold. On any other
    device the masks are converted to uint8 numpy arrays and dilated one by one
    with ``cv2.dilate`` (cv2 is imported lazily, only for that path).

    Args:
        mask (torch.Tensor): tensor of shape (N, H, W); binary masks are
            expected (e.g. dtype=torch.bool).
        kernel_size (int): side of the square structuring element; must be odd.

    Returns:
        torch.Tensor: the dilated masks, shape (N, H, W). The CUDA path returns a
        bool tensor; the other path returns a tensor with ``mask``'s dtype and
        device. Empty inputs are returned as a copy.

    Raises:
        AssertionError: if ``mask`` is not 3-D or ``kernel_size`` is even.
    """

    assert mask.ndim == 3
    kernel_size = int(kernel_size)
    assert (
        kernel_size % 2 == 1
    ), f"Dilation expects a odd kernel size, got {kernel_size}"

    if mask.is_cuda:
        m = mask.unsqueeze(1).to(torch.float16)
        k = torch.ones(1, 1, kernel_size, 1, dtype=m.dtype, device=m.device)

        # A square box kernel is separable: a (k x 1) pass then a (1 x k) pass.
        result = torch.nn.functional.conv2d(m, k, padding="same")
        result = torch.nn.functional.conv2d(result, k.transpose(-1, -2), padding="same")
        # Any positive response means the window contained a foreground pixel.
        return result.view_as(mask) > 0

    # Nothing to dilate (no masks, or zero-sized images): torch.stack cannot
    # stack an empty list and cv2 rejects empty images.
    if mask.numel() == 0:
        return mask.clone()

    # reshape (not view) so non-contiguous inputs, e.g. transposed masks, work.
    all_masks = mask.reshape(-1, mask.size(-2), mask.size(-1)).numpy().astype(np.uint8)
    kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8)

    import cv2

    processed = [torch.from_numpy(cv2.dilate(m, kernel)) for m in all_masks]
    # The stacked result is contiguous, so view_as is safe here.
    return torch.stack(processed).view_as(mask).to(mask)


def compute_F_measure(
    gt_boundary_rle, gt_dilated_boundary_rle, dt_boundary_rle, dt_dilated_boundary_rle
):
    """Boundary F-measure between a ground-truth and a predicted mask, from RLEs.

    Adapted from https://github.com/JonathonLuiten/TrackEval/blob/master/trackeval/metrics/j_and_f.py#L207

    The caller must already have computed the boundaries and dilated boundaries
    and converted them to RLEs accepted by ``pycocotools.mask``.

    Precision is the fraction of predicted boundary pixels lying inside the
    dilated ground-truth boundary. Recall is the fraction of ground-truth
    boundary pixels lying inside the dilated predicted boundary. When a boundary
    is empty, the TrackEval conventions below apply.

    Returns:
        The harmonic mean of precision and recall, or 0 when both are 0.
    """
    # Intersections (merge with intersect=True) of each boundary with the
    # other side's tolerance band.
    gt_hits = maskUtils.merge([gt_boundary_rle, dt_dilated_boundary_rle], True)
    dt_hits = maskUtils.merge([dt_boundary_rle, gt_dilated_boundary_rle], True)

    dt_count = maskUtils.area(dt_boundary_rle)
    gt_count = maskUtils.area(gt_boundary_rle)

    if dt_count > 0 and gt_count > 0:
        precision = maskUtils.area(dt_hits) / float(dt_count)
        recall = maskUtils.area(gt_hits) / float(gt_count)
    else:
        # Degenerate cases: an empty predicted boundary is perfectly precise,
        # and an empty ground-truth boundary is perfectly recalled.
        precision = 1 if dt_count == 0 else 0
        recall = 1 if gt_count == 0 else 0

    denominator = precision + recall
    if denominator == 0:
        return 0
    return 2 * precision * recall / denominator


@torch.no_grad()
def rle_encode(orig_mask, return_areas=False):
    """Encodes a collection of masks in RLE format

    This function emulates the behavior of the COCO API's encode function, but
    is executed partially on the GPU for faster execution.

    Args:
        orig_mask (torch.Tensor): A mask of shape (N, H, W) with dtype=torch.bool
        return_areas (bool): If True, add the areas of the masks as a part of
            the RLE output dict under the "area" key. Default is False.

    Returns:
        list[dict]: One RLE dict per mask, as produced by
            ``mask_util.frPyObjects`` but with "counts" decoded to ``str``, plus
            "area" (foreground pixel count) when ``return_areas`` is True.
            An empty list is returned when ``orig_mask`` has no elements
            (including the case where H or W is 0).

    Raises:
        AssertionError: if ``orig_mask`` is not 3-D or not boolean.
    """
    assert orig_mask.ndim == 3, "Mask must be of shape (N, H, W)"
    assert orig_mask.dtype == torch.bool, "Mask must have dtype=torch.bool"

    if orig_mask.numel() == 0:
        return []

    # First, transpose the spatial dimensions.
    # This is necessary because the COCO API uses Fortran order
    mask = orig_mask.transpose(1, 2)

    # Flatten the mask
    flat_mask = mask.reshape(mask.shape[0], -1)
    if return_areas:
        mask_areas = flat_mask.sum(-1).tolist()
    # Find the indices where the mask changes.
    # differences[n, j] is True when a new run starts at flat position j of mask n.
    # Column 0 is True iff the first pixel is foreground: COCO counts always
    # start with a background run, which then gets length 0. The extra last
    # column (position H*W) keeps its initial True, so every mask closes its
    # final run there.
    differences = torch.ones(
        mask.shape[0], flat_mask.shape[1] + 1, device=mask.device, dtype=torch.bool
    )
    differences[:, 1:-1] = flat_mask[:, :-1] != flat_mask[:, 1:]
    differences[:, 0] = flat_mask[:, 0]
    # torch.where scans row-major, so the positions come grouped by mask and
    # ascending within each mask.
    _, change_indices = torch.where(differences)

    # boundaries[i] is the exclusive end offset of mask i's positions inside
    # change_indices (mask i occupies [boundaries[i-1], boundaries[i])).
    # NOTE(review): the CPU retry presumably works around cumsum failing on the
    # mask's device — confirm which failure this guards against.
    try:
        boundaries = torch.cumsum(differences.sum(-1), 0).cpu()
    except RuntimeError as _:
        boundaries = torch.cumsum(differences.cpu().sum(-1), 0)

    # The clone preserves the absolute positions while change_indices is
    # overwritten in place with differences of consecutive positions.
    change_indices_clone = change_indices.clone()
    # First pass computes the RLEs on GPU, in a flatten format
    for i in range(mask.shape[0]):
        # Get the change indices for this batch item
        beg = 0 if i == 0 else boundaries[i - 1].item()
        end = boundaries[i].item()
        # Each entry becomes a run length; the first entry of the mask stays
        # absolute, i.e. the length of its leading background run.
        change_indices[beg + 1 : end] -= change_indices_clone[beg : end - 1]

    # Now we can split the RLES of each batch item, and convert them to strings
    # No more gpu at this point
    change_indices = change_indices.tolist()

    batch_rles = []
    # Process each mask in the batch separately
    for i in range(mask.shape[0]):
        beg = 0 if i == 0 else boundaries[i - 1].item()
        end = boundaries[i].item()
        run_lengths = change_indices[beg:end]

        # Let pycocotools compress the uncompressed (list) counts into its
        # byte-string format, then decode to str.
        uncompressed_rle = {"counts": run_lengths, "size": list(orig_mask.shape[1:])}
        h, w = uncompressed_rle["size"]
        rle = mask_util.frPyObjects(uncompressed_rle, h, w)
        rle["counts"] = rle["counts"].decode("utf-8")
        if return_areas:
            rle["area"] = mask_areas[i]
        batch_rles.append(rle)

    return batch_rles


def robust_rle_encode(masks):
    """Encode a batch of boolean masks as COCO RLE dicts, with a CPU fallback.

    The fast path is ``rle_encode``, which runs partially on the GPU. If it
    raises a ``RuntimeError``, the masks are copied to the CPU and encoded one
    at a time with ``pycocotools.mask.encode``. In both cases each returned
    dict has its "counts" decoded to ``str``.

    Args:
        masks (torch.Tensor): (N, H, W) tensor with dtype=torch.bool.

    Returns:
        list[dict]: one RLE dict per mask.
    """

    assert masks.ndim == 3, "Mask must be of shape (N, H, W)"
    assert masks.dtype == torch.bool, "Mask must have dtype=torch.bool"

    try:
        return rle_encode(masks)
    except RuntimeError:
        # pycocotools encodes Fortran-ordered (H, W, K) uint8 arrays; wrap each
        # mask as K=1 and keep the single resulting RLE.
        encoded = []
        for single_mask in masks.cpu().numpy():
            fortran_mask = np.array(
                single_mask[:, :, np.newaxis], dtype=np.uint8, order="F"
            )
            rle = mask_util.encode(fortran_mask)[0]
            rle["counts"] = rle["counts"].decode("utf-8")
            encoded.append(rle)
        return encoded


def ann_to_rle(segm, im_info):
    """Convert a COCO-style segmentation (polygons or RLE) into an RLE.

    Args:
        segm (list | dict): either a list of polygons (one object may consist of
            several parts), an uncompressed RLE dict whose "counts" is a list,
            or an already-compressed RLE dict.
        im_info (dict): image record providing "height" and "width".

    Returns:
        dict: the RLE. Polygons are rasterized and merged into a single RLE;
        uncompressed RLEs are compressed; compressed RLEs are returned as-is
        (the same object).
    """
    height, width = im_info["height"], im_info["width"]

    if isinstance(segm, list):
        # Polygon parts of a single object are merged into one mask RLE.
        return mask_util.merge(mask_util.frPyObjects(segm, height, width))

    if isinstance(segm["counts"], list):
        # Uncompressed RLE -> compressed RLE.
        return mask_util.frPyObjects(segm, height, width)

    # Already a compressed RLE.
    return segm