Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
| import torch | |
| from torch.nn import functional as F | |
| from maskrcnn_benchmark.layers import smooth_l1_loss | |
| from maskrcnn_benchmark.modeling.matcher import Matcher | |
| from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou | |
| from maskrcnn_benchmark.modeling.utils import cat | |
def project_masks_on_boxes(segmentation_masks, proposals, discretization_size):
    """
    Build fixed-size mask targets for the mask loss.

    Each segmentation mask is cropped to the box it is paired with,
    resized to a square grid and converted to its tensor ("mask")
    representation; the per-box results are then stacked.

    Arguments:
        segmentation_masks: an instance of SegmentationMask
        proposals: an instance of BoxList
        discretization_size: side length M of the (M, M) grid every
            cropped mask is resized to

    Returns:
        A float32 tensor on the device of ``proposals.bbox`` holding the
        stacked per-box masks, or an empty float32 tensor of shape
        ``(0,)`` on that device when there is nothing to project.
    """
    side = discretization_size
    # Remember the original device: the boxes are moved to the CPU below.
    target_device = proposals.bbox.device
    proposals = proposals.convert("xyxy")
    assert segmentation_masks.size == proposals.size, "{}, {}".format(
        segmentation_masks, proposals
    )

    # TODO put the proposals on the CPU, as the representation for the
    # masks is not efficient GPU-wise (possibly several small tensors for
    # representing a single instance mask)
    cpu_boxes = proposals.bbox.to(torch.device("cpu"))

    # crop -> resize to (side, side) -> tensor representation, one per box
    projected = [
        instance_mask.crop(box).resize((side, side)).convert(mode="mask")
        for instance_mask, box in zip(segmentation_masks, cpu_boxes)
    ]

    if not projected:
        return torch.empty(0, dtype=torch.float32, device=target_device)
    return torch.stack(projected, dim=0).to(target_device, dtype=torch.float32)
class MaskRCNNLossComputation(object):
    """
    Computes the mask-head loss: matches proposals to ground truth,
    projects the ground-truth masks onto the positive proposals and
    applies binary cross-entropy with logits.
    """

    def __init__(self, proposal_matcher, discretization_size, vl_version=False):
        """
        Arguments:
            proposal_matcher (Matcher)
            discretization_size (int)
            vl_version (bool): when True, targets are read from the
                "positive_map" field instead of the "labels" field
        """
        self.proposal_matcher = proposal_matcher
        self.discretization_size = discretization_size
        self.vl_version = vl_version

    def match_targets_to_proposals(self, proposal, target):
        """
        Pair every proposal with a ground-truth entry.

        Returns a copy of ``target`` indexed per proposal, carrying only
        the fields needed for the mask loss plus a "matched_idxs" field
        with the raw matcher output.
        """
        iou_matrix = boxlist_iou(target, proposal)
        matched_idxs = self.proposal_matcher(iou_matrix)

        # Mask RCNN needs the "labels" and "masks" fields for creating the
        # targets (the VL variant uses "positive_map" in place of "labels")
        kept_fields = (
            ["positive_map", "masks"] if self.vl_version else ["labels", "masks"]
        )
        target = target.copy_with_fields(kept_fields)

        # get the targets corresponding GT for each proposal
        # NB: need to clamp the indices because we can have a single
        # GT in the image, and matched_idxs can be -2, which goes
        # out of bounds
        matched_targets = target[matched_idxs.clamp(min=0)]
        matched_targets.add_field("matched_idxs", matched_idxs)
        return matched_targets

    def prepare_targets(self, proposals, targets):
        """
        Build per-image labels and projected mask targets.

        Returns three lists: int64 labels per image, mask targets per
        image (positive proposals only) and, in the VL variant, the
        per-image positive maps (an empty list otherwise).
        """
        all_labels = []
        all_masks = []
        all_positive_maps = []

        for boxes, gt in zip(proposals, targets):
            matched = self.match_targets_to_proposals(boxes, gt)
            matched_idxs = matched.get_field("matched_idxs")

            # zeroing the below-threshold entries can probably be removed,
            # but is left here for clarity and completeness
            background = matched_idxs == Matcher.BELOW_LOW_THRESHOLD

            if self.vl_version:
                token_map = matched.get_field("positive_map")
                token_map[background, :] = 0
                all_positive_maps.append(token_map)
                # TODO: make sure for the softmax [NoObj] case
                # NOTE(review): the int64 cast truncates, so a row whose
                # entries sum to less than 1 becomes label 0 (negative) --
                # confirm positive maps are not normalised upstream.
                img_labels = token_map.sum(dim=-1).to(dtype=torch.int64)
            else:
                img_labels = matched.get_field("labels").to(dtype=torch.int64)
                img_labels[background] = 0

            # mask scores are only computed on positive samples
            keep = torch.nonzero(img_labels > 0).squeeze(1)
            img_masks = project_masks_on_boxes(
                matched.get_field("masks")[keep],
                boxes[keep],
                self.discretization_size,
            )

            all_labels.append(img_labels)
            all_masks.append(img_masks)

        return all_labels, all_masks, all_positive_maps

    def __call__(self, proposals, mask_logits, targets):
        """
        Arguments:
            proposals (list[BoxList])
            mask_logits (Tensor)
            targets (list[BoxList])

        Return:
            mask_loss (Tensor): scalar tensor containing the loss
        """
        labels, mask_targets, positive_maps = self.prepare_targets(proposals, targets)

        labels = cat(labels, dim=0)
        mask_targets = cat(mask_targets, dim=0)

        positive_inds = torch.nonzero(labels > 0).squeeze(1)
        # TODO: a hack for binary mask head
        labels_pos = (labels[positive_inds] > 0).to(dtype=torch.int64)

        # torch.mean (in binary_cross_entropy_with_logits) doesn't
        # accept empty tensors, so handle it separately
        if mask_targets.numel() == 0:
            return mask_logits.sum() * 0

        if not self.vl_version:
            # pick, for each positive proposal, the channel given by its label
            return F.binary_cross_entropy_with_logits(
                mask_logits[positive_inds, labels_pos], mask_targets
            )

        positive_maps = cat(positive_maps, dim=0)
        # For each positive proposal, average the logit channels selected by
        # the non-zero entries of its positive map.
        # TODO: make sure for the softmax [NoObj] case
        pooled_logits = [
            mask_logits[idx][torch.nonzero(positive_maps[idx]).squeeze(1)].mean(
                dim=0, keepdim=True
            )
            for idx in positive_inds
        ]
        return F.binary_cross_entropy_with_logits(
            cat(pooled_logits, dim=0), mask_targets
        )
def make_roi_mask_loss_evaluator(cfg):
    """
    Build the mask loss evaluator from a config node.

    Reads the FG/BG IoU thresholds from ``cfg.MODEL.ROI_HEADS`` for the
    proposal matcher (low-quality matches disabled) and the mask
    resolution from ``cfg.MODEL.ROI_MASK_HEAD``. The VL variant of the
    loss is enabled when the configured predictor name starts with "VL".
    """
    roi_heads = cfg.MODEL.ROI_HEADS
    proposal_matcher = Matcher(
        roi_heads.FG_IOU_THRESHOLD,
        roi_heads.BG_IOU_THRESHOLD,
        allow_low_quality_matches=False,
    )

    mask_head = cfg.MODEL.ROI_MASK_HEAD
    return MaskRCNNLossComputation(
        proposal_matcher,
        mask_head.RESOLUTION,
        vl_version=mask_head.PREDICTOR.startswith("VL"),
    )