| |
| |
| |
| |
|
|
| from abc import ABCMeta, abstractmethod |
|
|
| import torch |
|
|
| from ..builder import MASK_ASSIGNERS, build_match_cost |
|
|
| try: |
| from scipy.optimize import linear_sum_assignment |
| except ImportError: |
| linear_sum_assignment = None |
|
|
|
|
class AssignResult(metaclass=ABCMeta):
    """Container bundling the outcome of one assignment call.

    Args:
        num_gts: Number of ground-truth instances that took part in the
            assignment.
        gt_inds: Assigned ground-truth indices (in this module: one entry per
            prediction, 0 for background, 1-based index otherwise).
        labels: Assigned labels (in this module: one entry per prediction,
            -1 when no label was assigned).
    """

    def __init__(self, num_gts, gt_inds, labels):
        self.num_gts, self.gt_inds, self.labels = num_gts, gt_inds, labels

    @property
    def info(self):
        """Return the three stored fields as a freshly built dict."""
        return dict(num_gts=self.num_gts, gt_inds=self.gt_inds, labels=self.labels)
|
|
|
|
class BaseAssigner(metaclass=ABCMeta):
    """Abstract interface for assigners that match predictions to ground truth.

    Concrete subclasses must implement :meth:`assign`; the class itself cannot
    be instantiated because that method is abstract.
    """

    @abstractmethod
    def assign(self, masks, gt_masks, gt_masks_ignore=None, gt_labels=None):
        """Assign each prediction either to a ground-truth instance or to
        background.

        Subclasses override this. The base body does nothing and returns
        ``None``.
        """
|
|
|
|
@MASK_ASSIGNERS.register_module()
class MaskHungarianAssigner(BaseAssigner):
    """Computes one-to-one matching between predictions and ground truth for
    mask.

    This class computes an assignment between the targets and the predictions
    based on the costs. The total cost is the sum of three components:
    classification cost, mask cost and dice cost. Each component is produced by
    a cost object built from its config. The cost objects are presumably
    responsible for applying their own ``weight`` (TODO confirm in
    ``build_match_cost``); this class only reads ``weight`` to skip components
    whose weight is 0. The targets don't include the no_object, so generally
    there are more predictions than targets. After the one-to-one matching,
    the un-matched predictions are treated as background. Thus each query
    prediction will be assigned with `0` or a positive integer indicating the
    ground truth index:

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        cls_cost (obj:`mmcv.ConfigDict`|dict): Classification cost config.
        dice_cost (obj:`mmcv.ConfigDict`|dict): Dice cost config.
        mask_cost (obj:`mmcv.ConfigDict`|dict): Mask cost config.
    """

    def __init__(
        self,
        cls_cost=dict(type="ClassificationCost", weight=1.0),
        dice_cost=dict(type="DiceCost", weight=1.0),
        mask_cost=dict(type="MaskFocalCost", weight=1.0),
    ):
        self.cls_cost = build_match_cost(cls_cost)
        self.dice_cost = build_match_cost(dice_cost)
        self.mask_cost = build_match_cost(mask_cost)

    def assign(self, cls_pred, mask_pred, gt_labels, gt_masks, img_meta, gt_masks_ignore=None, eps=1e-7):
        """Computes one-to-one matching based on the weighted costs.

        This method assigns each query prediction to a ground truth or
        background. In the returned result, `gt_inds` uses 0 for a negative
        sample and a positive number for the (1-based) index of the assigned
        gt. -1 ("don't care") only survives when there are no queries or no
        gts to match. The assignment is done in the following steps, and the
        order matters.

        1. assign every prediction to -1
        2. compute the weighted costs
        3. do Hungarian matching on CPU based on the costs
        4. assign all to 0 (background) first, then for each matched pair
           between predictions and gts, treat this prediction as foreground
           and assign the corresponding gt index (plus 1) to it.

        Args:
            cls_pred (Tensor): Predicted classification logits, shape
                [num_query, num_class].
            mask_pred (Tensor): Predicted mask, shape [num_query, h, w].
            gt_labels (Tensor): Label of `gt_masks`, shape (num_gt,).
            gt_masks (Tensor): Ground truth mask, shape [num_gt, h, w].
            img_meta (dict): Meta information for current image. Not used by
                this implementation; kept for interface compatibility.
            gt_masks_ignore (Tensor, optional): Ground truth masks that are
                labelled as `ignored`. Only ``None`` is supported.
            eps (int | float, optional): Not used by this implementation;
                kept for interface compatibility. Default 1e-7.

        Returns:
            :obj:`AssignResult`: The assigned result.

        Raises:
            AssertionError: If ``gt_masks_ignore`` is not ``None``.
            ImportError: If scipy is not installed and matching is needed.
        """
        assert gt_masks_ignore is None, "Only case when gt_masks_ignore is None is supported."
        num_gts, num_queries = gt_labels.shape[0], cls_pred.shape[0]

        # 1. Everything starts as "don't care" (-1).
        assigned_gt_inds = cls_pred.new_full((num_queries,), -1, dtype=torch.long)
        assigned_labels = cls_pred.new_full((num_queries,), -1, dtype=torch.long)
        if num_gts == 0 or num_queries == 0:
            # Nothing to match. Without any gt, every query is background.
            if num_gts == 0:
                assigned_gt_inds[:] = 0
            return AssignResult(num_gts, assigned_gt_inds, labels=assigned_labels)

        # Fail fast, before spending time on the cost computation below.
        if linear_sum_assignment is None:
            raise ImportError('Please run "pip install scipy" ' "to install scipy first.")

        # 2. Compute the costs; a component whose weight is 0 is skipped.
        if self.cls_cost.weight != 0 and cls_pred is not None:
            cls_cost = self.cls_cost(cls_pred, gt_labels)
        else:
            cls_cost = 0

        if self.mask_cost.weight != 0:
            mask_cost = self.mask_cost(mask_pred, gt_masks)
        else:
            mask_cost = 0

        if self.dice_cost.weight != 0:
            dice_cost = self.dice_cost(mask_pred, gt_masks)
        else:
            dice_cost = 0
        cost = cls_cost + mask_cost + dice_cost

        if not torch.is_tensor(cost):
            # All components are disabled, so `cost` is the plain int 0. Every
            # matching is then equally optimal; use an all-zero matrix so the
            # tensor operations below still work.
            cost = cls_pred.new_zeros((num_queries, num_gts))

        # 3. Hungarian matching on CPU. scipy consumes the matrix as a numpy
        # array, and numpy has no bfloat16, so low-precision costs (e.g. from
        # mixed-precision training) are upcast to float32. The upcast is
        # exact, so the matching is unchanged. Note: scipy rejects NaN/inf
        # entries with a ValueError.
        cost = cost.detach().cpu()
        if cost.dtype not in (torch.float32, torch.float64):
            cost = cost.float()
        matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
        matched_row_inds = torch.from_numpy(matched_row_inds).to(cls_pred.device)
        matched_col_inds = torch.from_numpy(matched_col_inds).to(cls_pred.device)

        # 4. Everything defaults to background, then matched queries become
        # foreground with a 1-based gt index and that gt's label.
        assigned_gt_inds[:] = 0
        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
        return AssignResult(num_gts, assigned_gt_inds, labels=assigned_labels)
|
|