| |
| from typing import List, Tuple, Union |
|
|
| import torch |
| import torch.nn.functional as F |
| from mmengine.structures import InstanceData |
| from torch import Tensor |
|
|
| from mmdet.models import BaseDetector |
| from mmdet.models.utils import unpack_gt_instances |
| from mmdet.registry import MODELS |
| from mmdet.structures import OptSampleList, SampleList |
| from mmdet.utils import ConfigType, OptConfigType |
|
|
|
|
@torch.jit.script
def rescoring_mask(scores, mask_pred, masks):
    """Scale each score by the mean soft-mask value inside its binary mask.

    Args:
        scores: Per-instance scores, one entry per instance.
        mask_pred: Binarised masks; cast to float so that it can act as a
            0/1 weight. Summed over dims 1 and 2, so it is expected to be
            3-D with instances first.
        masks: Soft mask values with the same layout as ``mask_pred``.

    Returns:
        ``scores`` multiplied by the average of ``masks`` over the pixels
        selected by ``mask_pred``. An empty binary mask yields a factor of
        0, because the small epsilon keeps the denominator non-zero.
    """
    weight = mask_pred.float()
    # Sum of soft values that fall inside the binary mask.
    inside_total = (masks * weight).sum([1, 2])
    # Foreground pixel count; 1e-6 guards against division by zero.
    inside_area = weight.sum([1, 2]) + 1e-6
    return scores * (inside_total / inside_area)
|
|
|
|
@MODELS.register_module()
class SparseInst(BaseDetector):
    """Implementation of `SparseInst <https://arxiv.org/abs/2203.12827>`_.

    The model chains a backbone, an encoder and a decoder. The decoder output
    is post-processed directly into instance masks; the ``bboxes`` field of
    the predictions is filled with all-zero placeholders.

    Args:
        data_preprocessor (:obj:`ConfigDict` or dict): Config of
            :class:`DetDataPreprocessor` to process the input data. It is
            forwarded to :class:`BaseDetector`.
        backbone (:obj:`ConfigDict` or dict): The backbone module.
        encoder (:obj:`ConfigDict` or dict): The encoder module.
        decoder (:obj:`ConfigDict` or dict): The decoder module.
        criterion (:obj:`ConfigDict` or dict, optional): The training matcher
            and losses. Required by :meth:`loss`. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): The testing config
            of SparseInst. It must provide ``score_thr`` and
            ``mask_thr_binary``. Required by :meth:`predict`.
            Defaults to None.
        init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
            the initialization. Defaults to None.
    """

    def __init__(self,
                 data_preprocessor: ConfigType,
                 backbone: ConfigType,
                 encoder: ConfigType,
                 decoder: ConfigType,
                 criterion: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptConfigType = None):
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        # Feature extraction: backbone followed by encoder.
        self.backbone = MODELS.build(backbone)
        self.encoder = MODELS.build(encoder)
        # The decoder produces the raw outputs consumed by ``predict`` and
        # ``loss``.
        self.decoder = MODELS.build(decoder)

        # Matcher & losses. They are only used by ``loss``, so a model built
        # without them (e.g. for inference only) is allowed.
        self.criterion = (
            MODELS.build(criterion) if criterion is not None else None)

        # Inference thresholds, only used by ``predict``. Subscript access
        # works for both plain dicts and ``ConfigDict`` objects.
        if test_cfg is not None:
            self.cls_threshold = test_cfg['score_thr']
            self.mask_threshold = test_cfg['mask_thr_binary']
        else:
            self.cls_threshold = None
            self.mask_threshold = None

    def _forward(
            self,
            batch_inputs: Tensor,
            batch_data_samples: OptSampleList = None) -> dict:
        """Run backbone, encoder and decoder without any post-processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (list[:obj:`DetDataSample`], optional):
                Unused; kept for interface compatibility with
                :class:`BaseDetector`. Defaults to None.

        Returns:
            dict: Raw decoder outputs keyed by name. :meth:`predict` reads
            the ``'pred_logits'``, ``'pred_masks'`` and ``'pred_scores'``
            entries, and :meth:`loss` passes the whole output to the
            criterion.
        """
        return self.decoder(self.extract_feat(batch_inputs))

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. Their ``metainfo`` must contain ``img_shape``, and
                also ``ori_shape`` when ``rescale`` is True.
            rescale (bool): Whether to resize the masks to ``ori_shape``.
                If False, the masks keep the ``img_shape`` size.
                Defaults to True.

        Returns:
            list[:obj:`DetDataSample`]: Detection results of the
            input images. Each DetDataSample contains ``pred_instances``
            with the following keys.

            - scores (Tensor): Classification scores re-weighted by mask
              quality, has a shape (num_instances, ).
            - labels (Tensor): Predicted class indices, has a shape
              (num_instances, ).
            - masks (Tensor): Boolean masks, has a shape
              (num_instances, H, W).
            - bboxes (Tensor): All-zero placeholders with shape
              (num_instances, 4); SparseInst does not predict boxes.

        Raises:
            RuntimeError: If the model was built without ``test_cfg``.
        """
        if self.cls_threshold is None or self.mask_threshold is None:
            raise RuntimeError(
                'SparseInst.predict requires `test_cfg` with `score_thr` '
                'and `mask_thr_binary`.')

        # Padded batch size; the decoder masks are resized to it before the
        # per-image padding is cropped away.
        max_shape = batch_inputs.shape[-2:]
        output = self._forward(batch_inputs)

        pred_scores = output['pred_logits'].sigmoid()
        pred_masks = output['pred_masks'].sigmoid()
        pred_objectness = output['pred_scores'].sigmoid()
        # Final score: geometric mean of class probability and objectness.
        pred_scores = torch.sqrt(pred_scores * pred_objectness)

        results_list = []
        for scores_per_image, mask_pred_per_image, datasample in zip(
                pred_scores, pred_masks, batch_data_samples):
            img_meta = datasample.metainfo
            h, w = img_meta['img_shape'][:2]
            if rescale:
                out_h, out_w = img_meta['ori_shape'][:2]
            else:
                out_h, out_w = h, w

            result = InstanceData()

            # Best class per candidate, then drop low-confidence candidates.
            scores, labels = scores_per_image.max(dim=-1)
            keep = scores > self.cls_threshold
            scores = scores[keep]
            labels = labels[keep]
            mask_pred_per_image = mask_pred_per_image[keep]

            if scores.size(0) == 0:
                # Keep the same fields as non-empty results so downstream
                # consumers see a consistent structure.
                result.scores = scores
                result.labels = labels
                result.masks = mask_pred_per_image.new_zeros(
                    (0, out_h, out_w), dtype=torch.bool)
                result.bboxes = scores.new_zeros(0, 4)
                results_list.append(result)
                continue

            scores = rescoring_mask(scores,
                                    mask_pred_per_image > self.mask_threshold,
                                    mask_pred_per_image)

            # Upsample to the padded input size, then crop to the
            # un-padded image size.
            mask_pred_per_image = F.interpolate(
                mask_pred_per_image.unsqueeze(1),
                size=max_shape,
                mode='bilinear',
                align_corners=False)[:, :, :h, :w]
            if rescale:
                mask_pred_per_image = F.interpolate(
                    mask_pred_per_image,
                    size=(out_h, out_w),
                    mode='bilinear',
                    align_corners=False)
            # Drop the channel dim on both paths so masks are always
            # (num_instances, H, W).
            mask_pred_per_image = mask_pred_per_image.squeeze(1)

            result.masks = mask_pred_per_image > self.mask_threshold
            result.scores = scores
            result.labels = labels
            # No box branch: zero boxes keep the InstanceData layout
            # expected by bbox-based consumers.
            result.bboxes = result.scores.new_zeros(len(scores), 4)
            results_list.append(result)

        batch_data_samples = self.add_pred_to_datasample(
            batch_data_samples, results_list)
        return batch_data_samples

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.

        Raises:
            RuntimeError: If the model was built without ``criterion``.
        """
        if self.criterion is None:
            raise RuntimeError(
                'SparseInst.loss requires `criterion` to be configured.')
        outs = self._forward(batch_inputs)
        (batch_gt_instances, batch_gt_instances_ignore,
         batch_img_metas) = unpack_gt_instances(batch_data_samples)

        losses = self.criterion(outs, batch_gt_instances, batch_img_metas,
                                batch_gt_instances_ignore)
        return losses

    def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
        """Extract features with the backbone followed by the encoder.

        Args:
            batch_inputs (Tensor): Image tensor with shape (N, C, H, W).

        Returns:
            tuple[Tensor]: The encoder output, which is the input of the
            decoder. Its exact structure depends on the configured encoder.
        """
        x = self.backbone(batch_inputs)
        return self.encoder(x)
|
|