Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| from maskrcnn_benchmark.structures.bounding_box import BoxList | |
def convert_mask_grounding_to_od_logits(logits, positive_map_label_to_token, num_classes):
    """
    Turn per-token mask logits into per-class mask probabilities.

    Arguments:
        logits (Tensor): indexed here as (dim0, token, dim2, dim3)
        positive_map_label_to_token (dict): maps a class label to the
            token indices (along dim 1 of ``logits``) that belong to it
        num_classes (int): number of classes; the output has
            ``num_classes + 1`` channels along dim 1

    Returns:
        Tensor: sigmoid of the aggregated logits. Channels whose label is
        absent from the map keep a logit of 0, i.e. a probability of 0.5.
    """
    batch_size = logits.shape[0]
    height = logits.shape[2]
    width = logits.shape[3]
    class_logits = torch.zeros(batch_size, num_classes + 1, height, width).to(logits.device)

    for label, token_ids in positive_map_label_to_token.items():
        # A class logit is the mean over the logits of all its tokens.
        token_index = torch.LongTensor(token_ids)
        class_logits[:, label, :, :] = logits[:, token_index, :, :].mean(1)

    return class_logits.sigmoid()
| # TODO check if want to return a single BoxList or a composite | |
| # object | |
class MaskPostProcessor(nn.Module):
    """
    Post-process the mask logits produced by the mask head.

    For every detection, the mask channel matching its label is selected
    and stored in the "mask" field of a fresh BoxList. When a masker
    object is supplied, the selected masks are first handed to it together
    with the boxes (to project the masks into the image).
    """

    def __init__(self, masker=None, mdetr_style_aggregate_class_num=None, vl_version=None):
        super().__init__()
        self.masker = masker
        self.mdetr_style_aggregate_class_num = mdetr_style_aggregate_class_num
        self.vl_version = vl_version

    def forward(self, x, boxes, positive_map_label_to_token=None):
        """
        Arguments:
            x (Tensor): the mask logits
            boxes (list[BoxList]): bounding boxes that are used as
                reference, one for each image
            positive_map_label_to_token (dict, optional): label -> token
                indices; only read when ``vl_version`` is set

        Returns:
            results (list[BoxList]): one BoxList for each image, containing
                the extra field mask
        """
        labels = torch.cat([per_image.get_field("labels") for per_image in boxes])

        if self.vl_version:
            # Grounding head: aggregate token logits into class probabilities.
            mask_prob = convert_mask_grounding_to_od_logits(
                x, positive_map_label_to_token, self.mdetr_style_aggregate_class_num
            )
        else:
            mask_prob = x.sigmoid()
            # TODO: a hack for binary mask head
            labels = (labels > 0).to(dtype=torch.int64)

        # Select the mask corresponding to each detection's label.
        row_index = torch.arange(x.shape[0], device=labels.device)
        mask_prob = mask_prob[row_index, labels][:, None]

        # Split the flat batch of masks back into one chunk per image.
        mask_prob = mask_prob.split([len(per_image) for per_image in boxes], dim=0)

        if self.masker:
            mask_prob = self.masker(mask_prob, boxes)

        results = []
        for prob, box in zip(mask_prob, boxes):
            # Copy the box together with all its fields, then attach the mask.
            bbox = BoxList(box.bbox, box.size, mode="xyxy")
            for field in box.fields():
                bbox.add_field(field, box.get_field(field))
            bbox.add_field("mask", prob)
            results.append(bbox)
        return results
class MaskPostProcessorCOCOFormat(MaskPostProcessor):
    """
    From the results of the CNN, post process the results
    so that the masks are pasted in the image, and
    additionally convert the results to COCO format.
    """

    def forward(self, x, boxes, positive_map_label_to_token=None, vl_version=None):
        """
        Arguments:
            x (Tensor): the mask logits
            boxes (list[BoxList]): reference boxes, one BoxList per image
            positive_map_label_to_token (dict, optional): forwarded to the
                base class, which reads it when ``self.vl_version`` is set
            vl_version: accepted for interface compatibility only; it is
                not read here (the base class uses ``self.vl_version``)

        Returns:
            results (list[BoxList]): one BoxList per image whose "mask"
                field is a list of RLE dicts with "counts" decoded to str
        """
        # Imported lazily so pycocotools is only needed when this class is used.
        import pycocotools.mask as mask_util

        # Forward the token map: without it the base class would receive
        # None in VL mode and fail while iterating over it.
        results = super(MaskPostProcessorCOCOFormat, self).forward(
            x, boxes, positive_map_label_to_token=positive_map_label_to_token
        )
        for result in results:
            masks = result.get_field("mask").cpu()
            if masks.dtype == torch.bool:
                # The RLE encoder takes a uint8 buffer; pasted masks are bool.
                masks = masks.to(torch.uint8)
            rles = [
                mask_util.encode(np.array(mask[0, :, :, np.newaxis], order="F"))[0]
                for mask in masks
            ]
            for rle in rles:
                # "counts" comes back as bytes; decode it to a plain str.
                rle["counts"] = rle["counts"].decode("utf-8")
            result.add_field("mask", rles)
        return results
| # the next two functions should be merged inside Masker | |
| # but are kept here for the moment while we need them | |
| # temporarily for paste_mask_in_image | |
def expand_boxes(boxes, scale):
    """
    Scale boxes about their centers.

    Arguments:
        boxes (Tensor): indexed as ``boxes[:, 0..3]`` = x1, y1, x2, y2
        scale (float): factor applied to each box's half width and half height

    Returns:
        Tensor: a new tensor shaped like ``boxes`` holding the scaled
        boxes; the input is not modified.
    """
    x_min, y_min = boxes[:, 0], boxes[:, 1]
    x_max, y_max = boxes[:, 2], boxes[:, 3]

    center_x = (x_max + x_min) * .5
    center_y = (y_max + y_min) * .5
    half_width = (x_max - x_min) * .5
    half_height = (y_max - y_min) * .5
    half_width *= scale
    half_height *= scale

    expanded = torch.zeros_like(boxes)
    expanded[:, 0] = center_x - half_width
    expanded[:, 1] = center_y - half_height
    expanded[:, 2] = center_x + half_width
    expanded[:, 3] = center_y + half_height
    return expanded
def expand_masks(mask, padding):
    """
    Zero-pad square masks on every side.

    Arguments:
        mask (Tensor): masks whose first dim is the mask count and whose
            last dim is the side length M (assumed square and
            broadcastable to (N, 1, M, M) - TODO confirm against callers)
        padding (int): number of zero pixels added on each side; 0 is
            allowed and returns an unpadded copy

    Returns:
        tuple: ``(padded_mask, scale)`` where ``padded_mask`` has shape
        (N, 1, M + 2 * padding, M + 2 * padding) and ``scale`` is the
        ratio of the padded side length to the original one.
    """
    num_masks = mask.shape[0]
    side = mask.shape[-1]
    padded_side = side + 2 * padding
    scale = float(padded_side) / side
    padded_mask = mask.new_zeros((num_masks, 1, padded_side, padded_side))
    # Use an explicit end index: `padding:-padding` collapses to the empty
    # slice `0:0` when padding == 0 and the assignment would then fail.
    end = padding + side
    padded_mask[:, :, padding:end, padding:end] = mask
    return padded_mask, scale
def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
    """
    Resize one mask to its box and paste it into an empty image-sized mask.

    Arguments:
        mask (Tensor): a single mask (assumed 2-D and square - TODO confirm)
        box (Tensor): the box of the mask, indexed as x1, y1, x2, y2
        im_h (int): height of the output mask
        im_w (int): width of the output mask
        thresh (float): when >= 0, the resized mask is binarized with
            ``mask > thresh``; when negative, see the note in the body
        padding (int): zero padding added around the mask before resizing

    Returns:
        Tensor: a bool tensor of shape (im_h, im_w)
    """
    # Pad the mask and grow the box by the matching scale factor.
    padded_mask, scale = expand_masks(mask[None], padding=padding)
    mask = padded_mask[0, 0]
    box = expand_boxes(box[None], scale)[0].to(dtype=torch.int32)
    left, top, right, bottom = box[0], box[1], box[2], box[3]

    TO_REMOVE = 1
    box_w = max(int(right - left + TO_REMOVE), 1)
    box_h = max(int(bottom - top + TO_REMOVE), 1)

    # Set shape to [batchxCxHxW], then resize the mask to the box size.
    mask = mask[None, None].to(torch.float32)
    mask = F.interpolate(mask, size=(box_h, box_w), mode='bilinear', align_corners=False)
    mask = mask[0, 0]

    if thresh >= 0:
        mask = mask > thresh
    else:
        # Intended for visualization and debugging.
        # NOTE(review): the cast to bool keeps only whether each scaled
        # value is non-zero, not the raw mask values.
        mask = (mask * 255).to(torch.bool)

    im_mask = torch.zeros((im_h, im_w), dtype=torch.bool)
    # Clip the paste region to the image bounds.
    x_0 = max(left, 0)
    x_1 = min(right + 1, im_w)
    y_0 = max(top, 0)
    y_1 = min(bottom + 1, im_h)

    # Copy the part of the resized mask that falls inside the image.
    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - top):(y_1 - top), (x_0 - left):(x_1 - left)]
    return im_mask
class Masker(object):
    """
    Projects a set of masks in an image on the locations
    specified by the bounding boxes
    """

    def __init__(self, threshold=0.5, padding=1):
        self.threshold = threshold
        self.padding = padding

    def forward_single_image(self, masks, boxes):
        """
        Paste every mask of one image at the location of its box.

        Arguments:
            masks (Tensor): the masks of one image; each element is
                indexed with ``[0]`` before pasting
            boxes (BoxList): the boxes of the same image

        Returns:
            Tensor: the pasted masks stacked along dim 0 with a singleton
            dim 1. With no masks, an empty tensor keeping the last two
            dims of ``masks`` is returned instead.
        """
        boxes = boxes.convert("xyxy")
        im_w, im_h = boxes.size

        pasted = []
        for mask, box in zip(masks, boxes.bbox):
            pasted.append(
                paste_mask_in_image(mask[0], box, im_h, im_w, self.threshold, self.padding)
            )

        if len(pasted) > 0:
            return torch.stack(pasted, dim=0)[:, None]
        return masks.new_empty((0, 1, masks.shape[-2], masks.shape[-1]))

    def __call__(self, masks, boxes):
        """
        Paste the masks of several images.

        Arguments:
            masks: one element per image
            boxes (BoxList or list[BoxList]): a single BoxList is wrapped
                into a one-element list

        Returns:
            list: one ``forward_single_image`` result per image
        """
        if isinstance(boxes, BoxList):
            boxes = [boxes]

        # Make some sanity check
        assert len(boxes) == len(masks), "Masks and boxes should have the same length."

        # TODO: Is this JIT compatible?
        # If not we should make it compatible.
        results = []
        for mask, box in zip(masks, boxes):
            assert mask.shape[0] == len(box), "Number of objects should be the same."
            results.append(self.forward_single_image(mask, box))
        return results
def make_roi_mask_post_processor(cfg):
    """
    Build the MaskPostProcessor described by the config.

    A Masker (using the configured threshold and a padding of 1) is
    attached only when ``cfg.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS`` is
    set. The VL code path is enabled when the mask predictor name starts
    with "VL".
    """
    masker = None
    if cfg.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS:
        masker = Masker(
            threshold=cfg.MODEL.ROI_MASK_HEAD.POSTPROCESS_MASKS_THRESHOLD,
            padding=1,
        )

    aggregate_class_num = cfg.TEST.MDETR_STYLE_AGGREGATE_CLASS_NUM
    is_vl_predictor = cfg.MODEL.ROI_MASK_HEAD.PREDICTOR.startswith("VL")
    return MaskPostProcessor(masker, aggregate_class_num, vl_version=is_vl_predictor)