import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from torch.cuda.amp import autocast

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.utils import reduce_mean


def compute_mask_iou(inputs, targets):
    """Compute per-instance mask IoU between predictions and targets.

    Both tensors are expected to be flattened over the spatial dimensions.
    Predictions are sigmoid-activated and binarized at 0.4, targets at 0.5.
    """
    inputs = inputs.sigmoid()
    binarized_inputs = (inputs >= 0.4).float()
    targets = (targets > 0.5).float()
    intersection = (binarized_inputs * targets).sum(-1)
    union = targets.sum(-1) + binarized_inputs.sum(-1) - intersection
    score = intersection / (union + 1e-6)
    return score
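

# Illustrative sketch (not part of the original module): a minimal check of
# `compute_mask_iou` on toy data, assuming per-instance masks flattened to
# shape (num_instances, H * W).
#
#   >>> logits = torch.randn(3, 16)              # raw mask logits
#   >>> gts = (torch.rand(3, 16) > 0.5).float()  # binary GT masks
#   >>> compute_mask_iou(logits, gts).shape
#   torch.Size([3])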


def dice_score(inputs, targets):
    """Compute the pairwise Dice similarity matrix.

    Returns a (num_preds, num_targets) score matrix computed from
    sigmoid-activated predictions.
    """
    inputs = inputs.sigmoid()
    numerator = 2 * torch.matmul(inputs, targets.t())
    denominator = (inputs * inputs).sum(-1)[:, None] + (
        targets * targets).sum(-1)
    score = numerator / (denominator + 1e-4)
    return score
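

# Illustrative sketch (not part of the original module): unlike
# `compute_mask_iou`, `dice_score` is pairwise (note the matmul), which makes
# it suitable for building the matching cost matrix used further below.
#
#   >>> preds = torch.randn(5, 16)  # 5 predicted masks, flattened
#   >>> gts = torch.rand(3, 16)     # 3 GT masks, flattened
#   >>> dice_score(preds, gts).shape
#   torch.Size([5, 3])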


@MODELS.register_module()
class SparseInstCriterion(nn.Module):
    """Loss criterion for SparseInst.

    This part is partially derived from
    https://github.com/facebookresearch/detr/blob/main/models/detr.py.
    """

    def __init__(
        self,
        num_classes,
        assigner,
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            alpha=0.25,
            gamma=2.0,
            reduction='sum',
            loss_weight=2.0),
        loss_obj=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=1.0),
        loss_mask=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=5.0),
        loss_dice=dict(
            type='DiceLoss',
            use_sigmoid=True,
            reduction='sum',
            eps=5e-5,
            loss_weight=2.0),
    ):
        super().__init__()
        self.matcher = TASK_UTILS.build(assigner)
        self.num_classes = num_classes
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_obj = MODELS.build(loss_obj)
        self.loss_mask = MODELS.build(loss_mask)
        self.loss_dice = MODELS.build(loss_dice)

    def _get_src_permutation_idx(self, indices):
        # permute predictions following the matched indices
        batch_idx = torch.cat(
            [torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following the matched indices
        batch_idx = torch.cat(
            [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx
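
    # Illustrative sketch (not part of the original module): for a batch of
    # two images with matcher output
    #   indices = [(tensor([2, 0]), tensor([0, 1])), (tensor([1]), tensor([0]))]
    # `_get_src_permutation_idx` returns
    #   batch_idx = tensor([0, 0, 1]), src_idx = tensor([2, 0, 1])
    # which together index a (B, N, ...) prediction tensor as
    # preds[batch_idx, src_idx].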

    def loss_classification(self, outputs, batch_gt_instances, indices,
                            num_instances):
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']
        idx = self._get_src_permutation_idx(indices)
        # gather the GT label of every matched prediction
        target_classes_o = torch.cat(
            [gt.labels[J] for gt, (_, J) in zip(batch_gt_instances, indices)])
        # unmatched predictions are assigned the background index
        target_classes = torch.full(
            src_logits.shape[:2],
            self.num_classes,
            dtype=torch.int64,
            device=src_logits.device)
        target_classes[idx] = target_classes_o

        src_logits = src_logits.flatten(0, 1)
        target_classes = target_classes.flatten(0, 1)
        class_loss = self.loss_cls(src_logits, target_classes) / num_instances
        return class_loss
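
    # Illustrative sketch (not part of the original module): with N=100
    # proposals and num_classes=80, every unmatched slot in `target_classes`
    # keeps the background index 80, while matched slots receive their GT
    # label, e.g. target_classes[0, 7] = 15 if prediction 7 of image 0 was
    # matched to a GT instance of class 15.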

    def loss_masks_with_iou_objectness(self, outputs, batch_gt_instances,
                                       indices, num_instances):
        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        assert 'pred_masks' in outputs
        assert 'pred_scores' in outputs
        src_iou_scores = outputs['pred_scores']
        src_masks = outputs['pred_masks']
        with torch.no_grad():
            target_masks = torch.cat([
                gt.masks.to_tensor(
                    dtype=src_masks.dtype, device=src_masks.device)
                for gt in batch_gt_instances
            ])
        num_masks = [len(gt.masks) for gt in batch_gt_instances]
        target_masks = target_masks.to(src_masks)
        if len(target_masks) == 0:
            # no GT instances: return zero losses that keep the graph
            # connected so that distributed training stays in sync
            loss_dice = src_masks.sum() * 0.0
            loss_mask = src_masks.sum() * 0.0
            loss_objectness = src_iou_scores.sum() * 0.0
            return loss_objectness, loss_dice, loss_mask

        src_masks = src_masks[src_idx]
        # resize GT masks to the prediction resolution
        target_masks = F.interpolate(
            target_masks[:, None],
            size=src_masks.shape[-2:],
            mode='bilinear',
            align_corners=False).squeeze(1)

        src_masks = src_masks.flatten(1)
        # shift the per-image target indices so that they address the
        # batch-wise concatenation of GT masks
        mix_tgt_idx = torch.zeros_like(tgt_idx[1])
        cum_sum = 0
        for num_mask in num_masks:
            mix_tgt_idx[cum_sum:cum_sum + num_mask] = cum_sum
            cum_sum += num_mask
        mix_tgt_idx += tgt_idx[1]
        target_masks = target_masks[mix_tgt_idx].flatten(1)

        # mask IoU of matched pairs supervises the objectness branch
        with torch.no_grad():
            ious = compute_mask_iou(src_masks, target_masks)

        tgt_iou_scores = ious
        src_iou_scores = src_iou_scores[src_idx]
        tgt_iou_scores = tgt_iou_scores.flatten(0)
        src_iou_scores = src_iou_scores.flatten(0)

        loss_objectness = self.loss_obj(src_iou_scores, tgt_iou_scores)
        loss_dice = self.loss_dice(src_masks, target_masks) / num_instances
        loss_mask = self.loss_mask(src_masks, target_masks)
        return loss_objectness, loss_dice, loss_mask
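
    # Illustrative sketch (not part of the original module): how the index
    # offsetting above works. With two images holding 2 and 3 GT masks
    # (num_masks = [2, 3]), the per-image indices in tgt_idx[1] are shifted
    # by the cumulative mask counts so that they address the batch-wise
    # concatenation of GT masks:
    #
    #   per-image tgt indices: [1, 0, 2, 0, 1]
    #   cumulative offsets:   +[0, 0, 2, 2, 2]
    #   concatenated indices:  [1, 0, 4, 2, 3]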

    def forward(self, outputs, batch_gt_instances, batch_img_metas,
                batch_gt_instances_ignore=None):
        # bipartite matching between predictions and GT instances
        indices = self.matcher(outputs, batch_gt_instances)

        # normalizer: mean GT count across GPUs, clamped to at least one
        num_instances = sum(gt.labels.shape[0] for gt in batch_gt_instances)
        num_instances = torch.as_tensor([num_instances],
                                        dtype=torch.float,
                                        device=next(iter(
                                            outputs.values())).device)
        num_instances = reduce_mean(num_instances).clamp_(min=1).item()

        loss_cls = self.loss_classification(outputs, batch_gt_instances,
                                            indices, num_instances)
        loss_obj, loss_dice, loss_mask = self.loss_masks_with_iou_objectness(
            outputs, batch_gt_instances, indices, num_instances)

        return dict(
            loss_cls=loss_cls,
            loss_obj=loss_obj,
            loss_dice=loss_dice,
            loss_mask=loss_mask)
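

# Illustrative sketch (not part of the original module): building the
# criterion from a config dict; shapes follow the SparseInst head
# (B images, N instance proposals, C classes, H x W mask resolution).
#
#   >>> criterion = MODELS.build(
#   ...     dict(
#   ...         type='SparseInstCriterion',
#   ...         num_classes=80,
#   ...         assigner=dict(type='SparseInstMatcher', alpha=0.8, beta=0.2)))
#   >>> # forward expects outputs with pred_logits (B, N, C),
#   >>> # pred_masks (B, N, H, W) and pred_scores (B, N, 1), and returns a
#   >>> # dict with loss_cls, loss_obj, loss_dice and loss_mask.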


@TASK_UTILS.register_module()
class SparseInstMatcher(nn.Module):
    """Dice-based bipartite matcher for SparseInst.

    The matching score combines mask Dice similarity and classification
    probability as ``dice**alpha * prob**beta``.
    """

    def __init__(self, alpha=0.8, beta=0.2):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.mask_score = dice_score

    def forward(self, outputs, batch_gt_instances):
        with torch.no_grad():
            B, N, H, W = outputs['pred_masks'].shape
            pred_masks = outputs['pred_masks']
            pred_logits = outputs['pred_logits'].sigmoid()
            device = pred_masks.device

            tgt_ids = torch.cat([gt.labels for gt in batch_gt_instances])

            # no GT in the whole batch: return empty matchings
            if tgt_ids.shape[0] == 0:
                return [(torch.as_tensor([]).to(pred_logits),
                         torch.as_tensor([]).to(pred_logits))] * B
            tgt_masks = torch.cat([
                gt.masks.to_tensor(dtype=pred_masks.dtype, device=device)
                for gt in batch_gt_instances
            ])

            # resize GT masks to the prediction resolution
            tgt_masks = F.interpolate(
                tgt_masks[:, None],
                size=pred_masks.shape[-2:],
                mode='bilinear',
                align_corners=False).squeeze(1)

            pred_masks = pred_masks.view(B * N, -1)
            tgt_masks = tgt_masks.flatten(1)
            # force fp32 for a numerically stable cost matrix
            with autocast(enabled=False):
                pred_masks = pred_masks.float()
                tgt_masks = tgt_masks.float()
                pred_logits = pred_logits.float()
                mask_score = self.mask_score(pred_masks, tgt_masks)
                # matching probability: (B * N, num_gts)
                matching_prob = pred_logits.view(B * N, -1)[:, tgt_ids]
                C = (mask_score**self.alpha) * (matching_prob**self.beta)

            C = C.view(B, N, -1).cpu()
            # per-image Hungarian matching, maximizing the combined score
            sizes = [len(gt.masks) for gt in batch_gt_instances]
            indices = [
                linear_sum_assignment(c[i], maximize=True)
                for i, c in enumerate(C.split(sizes, -1))
            ]
            indices = [(torch.as_tensor(i, dtype=torch.int64),
                        torch.as_tensor(j, dtype=torch.int64))
                       for i, j in indices]
            return indices
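

# Illustrative sketch (not part of the original module): the matcher returns
# one (pred_idx, tgt_idx) pair per image, e.g. for a batch of two images:
#
#   >>> matcher = TASK_UTILS.build(dict(type='SparseInstMatcher'))
#   >>> # indices = matcher(outputs, batch_gt_instances) might yield
#   >>> # [(tensor([ 7, 42]), tensor([0, 1])),         # image 0
#   >>> #  (tensor([ 3, 11, 56]), tensor([2, 0, 1]))]  # image 1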