| | |
| | import numpy as np |
| | from typing import Tuple |
| | import torch |
| | from PIL import Image |
| | from torch.nn import functional as F |
| |
|
# Public API of this module: only paste_masks_in_image is exported by `from ... import *`.
__all__ = ["paste_masks_in_image"]
| |
|
| |
|
# Bytes assumed per output pixel when paste_masks_in_image estimates how much
# memory the pasted masks need on GPU.
BYTES_PER_FLOAT = 4

# Memory budget (1 GiB) that paste_masks_in_image uses to decide how many chunks
# the GPU path is split into.
GPU_MEM_LIMIT = 1 << 30
| |
|
| |
|
| | def _do_paste_mask(masks, boxes, img_h: int, img_w: int, skip_empty: bool = True): |
| | """ |
| | Args: |
| | masks: N, 1, H, W |
| | boxes: N, 4 |
| | img_h, img_w (int): |
| | skip_empty (bool): only paste masks within the region that |
| | tightly bound all boxes, and returns the results this region only. |
| | An important optimization for CPU. |
| | |
| | Returns: |
| | if skip_empty == False, a mask of shape (N, img_h, img_w) |
| | if skip_empty == True, a mask of shape (N, h', w'), and the slice |
| | object for the corresponding region. |
| | """ |
| | |
| | |
| | |
| | |
| | device = masks.device |
| |
|
| | if skip_empty and not torch.jit.is_scripting(): |
| | x0_int, y0_int = torch.clamp(boxes.min(dim=0).values.floor()[:2] - 1, min=0).to( |
| | dtype=torch.int32 |
| | ) |
| | x1_int = torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32) |
| | y1_int = torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32) |
| | else: |
| | x0_int, y0_int = 0, 0 |
| | x1_int, y1_int = img_w, img_h |
| | x0, y0, x1, y1 = torch.split(boxes, 1, dim=1) |
| |
|
| | N = masks.shape[0] |
| |
|
| | img_y = torch.arange(y0_int, y1_int, device=device, dtype=torch.float32) + 0.5 |
| | img_x = torch.arange(x0_int, x1_int, device=device, dtype=torch.float32) + 0.5 |
| | img_y = (img_y - y0) / (y1 - y0) * 2 - 1 |
| | img_x = (img_x - x0) / (x1 - x0) * 2 - 1 |
| | |
| |
|
| | gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1)) |
| | gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1)) |
| | grid = torch.stack([gx, gy], dim=3) |
| |
|
| | if not torch.jit.is_scripting(): |
| | if not masks.dtype.is_floating_point: |
| | masks = masks.float() |
| | img_masks = F.grid_sample(masks, grid.to(masks.dtype), align_corners=False) |
| |
|
| | if skip_empty and not torch.jit.is_scripting(): |
| | return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int)) |
| | else: |
| | return img_masks[:, 0], () |
| |
|
| |
|
| | |
@torch.jit.script_if_tracing
def paste_masks_in_image(
    masks: torch.Tensor, boxes: torch.Tensor, image_shape: Tuple[int, int], threshold: float = 0.5
):
    """
    Paste a set of masks that are of a fixed resolution (e.g., 28 x 28) into an image.
    The location, height, and width for pasting each mask is determined by their
    corresponding bounding boxes in boxes.

    Note:
        This is a complicated but more accurate implementation. In actual deployment, it is
        often enough to use a faster but less accurate implementation.
        See :func:`paste_mask_in_image_old` in this file for an alternative implementation.

    Args:
        masks (tensor): Tensor of shape (Bimg, Hmask, Wmask), where Bimg is the number of
            detected object instances in the image and Hmask, Wmask are the mask height and
            mask width of the predicted mask (e.g., Hmask = Wmask = 28). Values are in [0, 1].
            Only square masks (Hmask == Wmask) are supported.
        boxes (Boxes or Tensor): A Boxes of length Bimg or Tensor of shape (Bimg, 4).
            boxes[i] and masks[i] correspond to the same object instance.
        image_shape (tuple): height, width
        threshold (float): A threshold in [0, 1] for converting the (soft) masks to
            binary masks. A negative threshold disables binarization: the soft masks
            are returned multiplied by 255 and cast to uint8.

    Returns:
        img_masks (Tensor): A tensor of shape (Bimg, Himage, Wimage), where Bimg is the
        number of detected object instances and Himage, Wimage are the image height
        and width. img_masks[i] is the mask for object instance i. The dtype is
        torch.bool when threshold >= 0 and torch.uint8 otherwise, including when
        Bimg == 0.
    """
    assert masks.shape[-1] == masks.shape[-2], "Only square mask predictions are supported"
    N = len(masks)
    if N == 0:
        # Use the same dtype as the non-empty path so callers always get a
        # consistent result type.
        return masks.new_empty(
            (0,) + image_shape, dtype=torch.bool if threshold >= 0 else torch.uint8
        )
    if not isinstance(boxes, torch.Tensor):
        boxes = boxes.tensor
    device = boxes.device
    assert len(boxes) == N, boxes.shape

    img_h, img_w = image_shape

    if device.type == "cpu" or torch.jit.is_scripting():
        # One mask per chunk; on CPU each chunk is pasted only over the region
        # around its box (skip_empty below).
        num_chunks = N
    else:
        # Split so that each chunk's full-image output stays within GPU_MEM_LIMIT,
        # assuming BYTES_PER_FLOAT bytes per output pixel.
        num_chunks = int(np.ceil(N * int(img_h) * int(img_w) * BYTES_PER_FLOAT / GPU_MEM_LIMIT))
        assert (
            num_chunks <= N
        ), "Default GPU_MEM_LIMIT in mask_ops.py is too small; try increasing it"
    chunks = torch.chunk(torch.arange(N, device=device), num_chunks)

    img_masks = torch.zeros(
        N, img_h, img_w, device=device, dtype=torch.bool if threshold >= 0 else torch.uint8
    )
    for inds in chunks:
        masks_chunk, spatial_inds = _do_paste_mask(
            masks[inds, None, :, :], boxes[inds], img_h, img_w, skip_empty=device.type == "cpu"
        )

        if threshold >= 0:
            masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)
        else:
            # Keep the soft masks (e.g. for visualization/debugging), scaled by 255.
            masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)

        if torch.jit.is_scripting():
            # When scripting, _do_paste_mask returns an empty region tuple, so the
            # chunk is written as a whole.
            img_masks[inds] = masks_chunk
        else:
            img_masks[(inds,) + spatial_inds] = masks_chunk
    return img_masks
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def paste_mask_in_image_old(mask, box, img_h, img_w, threshold):
    """
    Paste a single mask in an image.
    This is a per-box implementation of :func:`paste_masks_in_image`.
    This function has larger quantization error due to incorrect pixel
    modeling and is not used any more.

    Args:
        mask (Tensor): A tensor of shape (Hmask, Wmask) storing the mask of a single
            object instance. Values are in [0, 1].
        box (Tensor): A tensor of shape (4, ) storing the x0, y0, x1, y1 box corners
            of the object instance.
        img_h, img_w (int): Image height and width.
        threshold (float): Mask binarization threshold in [0, 1]. A negative value
            returns the resized soft mask multiplied by 255 as uint8.

    Returns:
        im_mask (Tensor):
            The resized and binarized object mask pasted into the original
            image plane (a tensor of shape (img_h, img_w)). Pixels outside the
            box (or outside the image) are 0.
    """
    # Continuous -> discrete box coordinates by truncation (cast to int32); this
    # decides which pixels the mask is pasted onto. Converted to Python ints so
    # the sizes and offsets below are plain integers (PIL resize takes ints).
    bx0, by0, bx1, by1 = box.to(dtype=torch.int32).tolist()
    # A discrete box spanning x0..x1 (inclusive) covers x1 - x0 + 1 pixel
    # samples, which is *not* the geometric width.
    samples_w = bx1 - bx0 + 1
    samples_h = by1 - by0 + 1

    # Resample the mask from its original grid to the samples_w x samples_h grid.
    mask = Image.fromarray(mask.cpu().numpy())
    mask = mask.resize((samples_w, samples_h), resample=Image.BILINEAR)
    # np.asarray rather than np.array(..., copy=False): under NumPy 2 the latter
    # raises when a copy cannot be avoided.
    mask = np.asarray(mask)

    if threshold >= 0:
        mask = torch.from_numpy(np.array(mask > threshold, dtype=np.uint8))
    else:
        # For visualization and debugging, also allow returning the
        # unbinarized mask (scaled by 255).
        mask = torch.from_numpy(mask * 255).to(torch.uint8)

    im_mask = torch.zeros((img_h, img_w), dtype=torch.uint8)
    # Clip the box to the image. Each end is kept >= its start so that a box
    # lying completely outside the image yields matching empty slices instead of
    # negative/inverted indices.
    x_0 = max(bx0, 0)
    x_1 = max(min(bx1 + 1, img_w), x_0)
    y_0 = max(by0, 0)
    y_1 = max(min(by1 + 1, img_h), y_0)

    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - by0) : (y_1 - by0), (x_0 - bx0) : (x_1 - bx0)]
    return im_mask
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
def pad_masks(masks, padding):
    """
    Zero-pad square masks by the same number of cells on every side.

    Args:
        masks (tensor): A tensor of shape (B, M, M) representing B masks.
        padding (int): Number of cells to pad on all sides. 0 returns an
            unpadded copy.

    Returns:
        The padded masks of shape (B, M + 2 * padding, M + 2 * padding) and the
        scale factor of the padded size / original size.
    """
    B = masks.shape[0]
    M = masks.shape[-1]
    pad2 = 2 * padding
    scale = float(M + pad2) / M
    padded_masks = masks.new_zeros((B, M + pad2, M + pad2))
    # Explicit end index: the slice `padding:-padding` is empty when padding == 0.
    padded_masks[:, padding : padding + M, padding : padding + M] = masks
    return padded_masks, scale
| |
|
| |
|
def scale_boxes(boxes, scale):
    """
    Grow or shrink boxes about their own centers.

    Args:
        boxes (tensor): A tensor of shape (B, 4) whose columns are the corners
            x0, y0, x1, y1 of B boxes.
        scale (float): Factor applied to each box's width and height.

    Returns:
        A new tensor shaped like ``boxes`` (same dtype and device) holding the
        scaled boxes.
    """
    left, top, right, bottom = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]

    center_x = (right + left) * 0.5
    center_y = (bottom + top) * 0.5
    half_width = (right - left) * 0.5
    half_height = (bottom - top) * 0.5
    half_width *= scale
    half_height *= scale

    out = torch.zeros_like(boxes)
    # Columns 0 and 2 are x coordinates; columns 1 and 3 are y coordinates.
    out[:, 0:4:2] = torch.stack((center_x - half_width, center_x + half_width), dim=1)
    out[:, 1:4:2] = torch.stack((center_y - half_height, center_y + half_height), dim=1)
    return out
| |
|
| |
|
@torch.jit.script_if_tracing
def _paste_masks_tensor_shape(
    masks: torch.Tensor,
    boxes: torch.Tensor,
    image_shape: Tuple[torch.Tensor, torch.Tensor],
    threshold: float = 0.5,
):
    """
    Variant of paste_masks_in_image that takes the image size as a pair of tensors.

    While tracing, the image height/width may be tensors rather than Python ints.
    The Tensor -> int conversion is done here so that it is scripted rather than
    traced.
    """
    height = int(image_shape[0])
    width = int(image_shape[1])
    return paste_masks_in_image(masks, boxes, (height, width), threshold)
| |
|