import numpy as np import torch from torch import nn from torch.nn import functional as F from util.poly_ops import get_all_order_corners try: from diff_ras.polygon import SoftPolygon except ImportError: SoftPolygon = None from util.bf_utils import POLY_LOSS_REGISTRY, rasterize_instances def custom_L1_loss(src_polys, target_polys, target_len): """L1 loss for coordinates regression We only calculate the loss between valid corners since we filter out invalid corners in final results Args: src_polys: Tensor of dim [num_target_polys, num_queries_per_poly*2] with the matched predicted polygons coordinates target_polys: Tensor of dim [num_target_polys, num_queries_per_poly*2] with the target polygons coordinates target_len: list of size num_target_polys, each element indicates 2 * num_corners of this poly """ total_loss = 0.0 for i in range(target_polys.shape[0]): tgt_poly_single = target_polys[i, : target_len[i]] all_polys = get_all_order_corners(tgt_poly_single) total_loss += torch.cdist(src_polys[i, : target_len[i]].unsqueeze(0), all_polys, p=1).min() total_loss = total_loss / target_len.sum() return total_loss class ClippingStrategy(nn.Module): def __init__(self, cfg, is_boundary=False): super().__init__() self.register_buffer( "laplacian", torch.tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1], dtype=torch.float32).reshape(1, 1, 3, 3) ) self.is_boundary = is_boundary self.side_lengths = np.array([64, 64, 64, 64, 64, 64, 64, 64]).reshape(-1, 2) # not used. def _extract_target_boundary(self, masks, shape): boundary_targets = F.conv2d(masks.unsqueeze(1), self.laplacian, padding=1) boundary_targets = boundary_targets.clamp(min=0) boundary_targets[boundary_targets > 0.1] = 1 boundary_targets[boundary_targets <= 0.1] = 0 # odd? only if the width doesn't match? if boundary_targets.shape[-2:] != shape: boundary_targets = F.interpolate(boundary_targets, shape, mode="nearest") return boundary_targets def forward(self, instances, clip_boxes=None, lid=0): device = self.laplacian.device gt_masks = [] if clip_boxes is not None: clip_boxes = torch.split(clip_boxes, [len(inst) for inst in instances], dim=0) for idx, instances_per_image in enumerate(instances): if len(instances_per_image) == 0: continue if clip_boxes is not None: # todo, need to support rectangular boxes. gt_masks_per_image = instances_per_image.gt_masks.crop_and_resize( clip_boxes[idx].detach(), self.side_lengths[lid][0] ) else: gt_masks_per_image = instances_per_image.gt_masks.rasterize_no_crop(self.side_length).to(device) # A tensor of shape (N, M, M), N=#instances in the image; M=mask_side_len gt_masks.append(gt_masks_per_image) return torch.cat(gt_masks).squeeze(1) def dice_loss(input, target): smooth = 1.0 iflat = input.reshape(-1) tflat = target.reshape(-1) intersection = (iflat * tflat).sum() return 1 - ((2.0 * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth)) def dice_loss_no_reduction(input, target): smooth = 1.0 iflat = input.flatten(-2, -1) # [200, 4096] tflat = target.flatten(-2, -1) # [200, 4096] intersection = (iflat * tflat).sum(1) # [200] return 1 - ((2.0 * intersection + smooth) / (iflat.sum(1) + tflat.sum(1) + smooth)) @POLY_LOSS_REGISTRY.register() class MaskRasterizationLoss(nn.Module): def __init__(self, cfg): super().__init__() self.register_buffer( "rasterize_at", torch.from_numpy(np.array([64, 64, 64, 64, 64, 64, 64, 64]).reshape(-1, 2)) ) # self.register_buffer("rasterize_at", torch.from_numpy(np.array([128, 128, 128, 128, 128, 128, 128, 128]).reshape(-1, 2))) # self.register_buffer("rasterize_at", torch.from_numpy(np.array([256, 256, 256, 256, 256, 256, 256, 256]).reshape(-1, 2))) self.inv_smoothness_schedule = (0.1,) self.inv_smoothness = self.inv_smoothness_schedule[0] self.inv_smoothness_iter = () self.inv_smoothness_idx = 0 self.iter = 0 # whether to invoke our own rasterizer in "hard" mode. self.use_rasterized_gt = True self.pred_rasterizer = SoftPolygon(inv_smoothness=self.inv_smoothness, mode="mask") self.clip_to_proposal = False self.predict_in_box_space = True if self.clip_to_proposal or not self.use_rasterized_gt: self.clipper = ClippingStrategy(cfg=None) self.gt_rasterizer = None else: self.gt_rasterizer = SoftPolygon(inv_smoothness=1.0, mode="hard_mask") self.offset = 0.5 self.loss_fn = dice_loss self.name = "mask" def _create_targets(self, instances, clip_boxes=None, lid=0): if self.clip_to_proposal or not self.use_rasterized_gt: targets = self.clipper(instances, clip_boxes=clip_boxes, lid=lid) else: targets = rasterize_instances(self.gt_rasterizer, instances, self.rasterize_at) return targets def forward(self, preds, targets, target_len, lid=0): resolution = self.rasterize_at[lid] target_masks = [] pred_masks = [] for i in range(len(targets)): # tgt_poly_single = targets[i, :target_len[i]].view(-1, 2).unsqueeze(0) # pred_poly_single = preds[i, :target_len[i]].view(-1, 2).unsqueeze(0) tgt_poly_single = targets[i][: target_len[i]].view(-1, 2).unsqueeze(0) pred_poly_single = preds[i][: target_len[i]].view(-1, 2).unsqueeze(0) tgt_mask = self.gt_rasterizer( tgt_poly_single * float(resolution[1].item()), resolution[1].item(), resolution[0].item(), 1.0 ) tgt_mask = (tgt_mask + 1) / 2 pred_mask = self.pred_rasterizer( pred_poly_single * float(resolution[1].item()), resolution[1].item(), resolution[0].item(), 1.0 ) target_masks.append(tgt_mask) pred_masks.append(pred_mask) pred_masks = torch.stack(pred_masks) target_masks = torch.stack(target_masks) return self.loss_fn(pred_masks, target_masks) class MaskRasterizationCost(nn.Module): def __init__(self, cfg): super().__init__() self.register_buffer( "rasterize_at", torch.from_numpy(np.array([64, 64, 64, 64, 64, 64, 64, 64]).reshape(-1, 2)) ) # self.register_buffer("rasterize_at", torch.from_numpy(np.array([128, 128, 128, 128, 128, 128, 128, 128]).reshape(-1, 2))) self.inv_smoothness_schedule = (0.1,) self.inv_smoothness = self.inv_smoothness_schedule[0] self.inv_smoothness_iter = () self.inv_smoothness_idx = 0 self.iter = 0 self.pred_rasterizer = SoftPolygon(inv_smoothness=self.inv_smoothness, mode="mask") # whether to invoke our own rasterizer in "hard" mode. self.use_rasterized_gt = True self.gt_rasterizer = SoftPolygon(inv_smoothness=1.0, mode="hard_mask") self.offset = 0.5 self.loss_fn = dice_loss_no_reduction self.name = "mask" def mask_iou( self, mask1: torch.Tensor, mask2: torch.Tensor, ) -> torch.Tensor: """ Inputs: mask1: NxHxW torch.float32. Consists of [0, 1] mask2: NxHxW torch.float32. Consists of [0, 1] Outputs: ret: NxM torch.float32. Consists of [0 - 1] """ N, H, W = mask1.shape M, H, W = mask2.shape mask1 = mask1.view(N, H * W) mask2 = mask2.view(M, H * W) intersection = torch.matmul(mask1, mask2.t()) area1 = mask1.sum(dim=1).view(1, -1) area2 = mask2.sum(dim=1).view(1, -1) union = (area1.t() + area2) - intersection ret = torch.where( union == 0, torch.tensor(0.0, device=mask1.device), intersection / union, ) return ret def forward(self, preds, targets, target_len, lid=0): resolution = self.rasterize_at[lid] cost_mask = torch.zeros([preds.shape[0], targets.shape[0]], device=preds.device) pred_masks = [] for i in range(targets.shape[0]): tgt_poly_single = targets[i, : target_len[i]].view(-1, 2).unsqueeze(0) pred_poly_all = preds[:, : target_len[i]].view(preds.shape[0], -1, 2) tgt_mask = self.gt_rasterizer( tgt_poly_single * float(resolution[1].item()), resolution[1].item(), resolution[0].item(), 1.0 ) pred_masks = self.pred_rasterizer( pred_poly_all * float(resolution[1].item()), resolution[1].item(), resolution[0].item(), 1.0 ) tgt_mask = (tgt_mask + 1) / 2 tgt_masks = tgt_mask.repeat(preds.shape[0], 1, 1) cost_mask[:, i] = self.loss_fn(tgt_masks, pred_masks) return cost_mask