| |
| import copy |
| from typing import Dict, List, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from mmcv.cnn import ConvModule, Scale |
| from mmengine.config import ConfigDict |
| from mmengine.model import BaseModule, kaiming_init |
| from mmengine.structures import InstanceData |
| from torch import Tensor |
|
|
| from mmdet.registry import MODELS |
| from mmdet.structures.bbox import cat_boxes |
| from mmdet.utils import (ConfigType, InstanceList, MultiConfig, OptConfigType, |
| OptInstanceList, reduce_mean) |
| from ..task_modules.prior_generators import MlvlPointGenerator |
| from ..utils import (aligned_bilinear, filter_scores_and_topk, multi_apply, |
| relative_coordinate_maps, select_single_mlvl) |
| from ..utils.misc import empty_instances |
| from .base_mask_head import BaseMaskHead |
| from .fcos_head import FCOSHead |
|
|
| INF = 1e8 |
|
|
|
|
@MODELS.register_module()
class CondInstBboxHead(FCOSHead):
    """CondInst box head used in https://arxiv.org/abs/2003.05664.

    Note that CondInst Bbox Head is an extension of FCOS head.
    Two differences are described as follows:

    1. CondInst box head predicts a set of params for each instance.
    2. CondInst box head return the pos_gt_inds and pos_inds.

    Args:
        num_params (int): Number of params for instance segmentation.
            Must match ``CondInstMaskHead.num_params``.
    """

    def __init__(self, *args, num_params: int = 169, **kwargs) -> None:
        # ``num_params`` must be assigned before ``super().__init__`` since
        # the base constructor is expected to call ``_init_layers``, which
        # reads it to build the controller conv.
        # NOTE(review): confirm FCOSHead.__init__ triggers _init_layers.
        self.num_params = num_params
        super().__init__(*args, **kwargs)
|
|
| def _init_layers(self) -> None: |
| """Initialize layers of the head.""" |
| super()._init_layers() |
| self.controller = nn.Conv2d( |
| self.feat_channels, self.num_params, 3, padding=1) |
|
|
    def forward_single(self, x: Tensor, scale: Scale,
                       stride: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Forward features of a single scale level.

        Args:
            x (Tensor): FPN feature maps of the specified stride.
            scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize
                the bbox prediction.
            stride (int): The corresponding stride for feature maps, only
                used to normalize the bbox prediction when self.norm_on_bbox
                is True.

        Returns:
            tuple: scores for each class, bbox predictions, centerness
            predictions and param predictions of input feature maps.
        """
        # Deliberately skip FCOSHead.forward_single and call the
        # implementation above it in the MRO, so the raw cls/reg tower
        # features come back and can be reused for the centerness and
        # controller branches below.
        cls_score, bbox_pred, cls_feat, reg_feat = \
            super(FCOSHead, self).forward_single(x)
        if self.centerness_on_reg:
            centerness = self.conv_centerness(reg_feat)
        else:
            centerness = self.conv_centerness(cls_feat)

        # Per-level learnable scaling of the 4 distance predictions.
        bbox_pred = scale(bbox_pred).float()
        if self.norm_on_bbox:
            # Distances must be non-negative. During training the targets
            # are stride-normalized as well (see ``get_targets``), so the
            # predictions are only rescaled back by the stride at test time.
            bbox_pred = bbox_pred.clamp(min=0)
            if not self.training:
                bbox_pred *= stride
        else:
            bbox_pred = bbox_pred.exp()
        # Dynamic-filter parameters for the mask head, one vector per point.
        param_pred = self.controller(reg_feat)
        return cls_score, bbox_pred, centerness, param_pred
|
|
    def loss_by_feat(
        self,
        cls_scores: List[Tensor],
        bbox_preds: List[Tensor],
        centernesses: List[Tensor],
        param_preds: List[Tensor],
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        batch_gt_instances_ignore: OptInstanceList = None
    ) -> Dict[str, Tensor]:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * 4.
            centernesses (list[Tensor]): centerness for each scale level, each
                is a 4D-tensor, the channel number is num_points * 1.
            param_preds (List[Tensor]): param_pred for each scale level, each
                is a 4D-tensor, the channel number is num_params.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        assert len(cls_scores) == len(bbox_preds) == len(centernesses)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        # Priors are generated together with per-point strides so the mask
        # head can later map points back to image coordinates.
        all_level_points_strides = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=bbox_preds[0].dtype,
            device=bbox_preds[0].device,
            with_stride=True)
        all_level_points = [i[:, :2] for i in all_level_points_strides]
        all_level_strides = [i[:, 2] for i in all_level_points_strides]
        labels, bbox_targets, pos_inds_list, pos_gt_inds_list = \
            self.get_targets(all_level_points, batch_gt_instances)

        num_imgs = cls_scores[0].size(0)
        # Flatten every level to (num_imgs * num_points_lvl, C), then
        # concatenate all levels into one flat prediction tensor.
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_centerness = [
            centerness.permute(0, 2, 3, 1).reshape(-1)
            for centerness in centernesses
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        flatten_centerness = torch.cat(flatten_centerness)
        flatten_labels = torch.cat(labels)
        flatten_bbox_targets = torch.cat(bbox_targets)
        # Repeat points over images so they align with the flattened preds.
        flatten_points = torch.cat(
            [points.repeat(num_imgs, 1) for points in all_level_points])

        # FG cat_id range: [0, num_classes - 1]; BG cat_id: num_classes.
        bg_class_ind = self.num_classes
        pos_inds = ((flatten_labels >= 0)
                    & (flatten_labels < bg_class_ind)).nonzero().reshape(-1)
        # Average positive count across workers for a stable avg_factor.
        num_pos = torch.tensor(
            len(pos_inds), dtype=torch.float, device=bbox_preds[0].device)
        num_pos = max(reduce_mean(num_pos), 1.0)
        loss_cls = self.loss_cls(
            flatten_cls_scores, flatten_labels, avg_factor=num_pos)

        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_centerness = flatten_centerness[pos_inds]
        pos_bbox_targets = flatten_bbox_targets[pos_inds]
        pos_centerness_targets = self.centerness_target(pos_bbox_targets)
        # Box loss is centerness-weighted; normalize by the summed weights.
        centerness_denorm = max(
            reduce_mean(pos_centerness_targets.sum().detach()), 1e-6)

        if len(pos_inds) > 0:
            pos_points = flatten_points[pos_inds]
            pos_decoded_bbox_preds = self.bbox_coder.decode(
                pos_points, pos_bbox_preds)
            pos_decoded_target_preds = self.bbox_coder.decode(
                pos_points, pos_bbox_targets)
            loss_bbox = self.loss_bbox(
                pos_decoded_bbox_preds,
                pos_decoded_target_preds,
                weight=pos_centerness_targets,
                avg_factor=centerness_denorm)
            loss_centerness = self.loss_centerness(
                pos_centerness, pos_centerness_targets, avg_factor=num_pos)
        else:
            # Zero-valued losses that keep the graph connected when there
            # are no positives in the batch.
            loss_bbox = pos_bbox_preds.sum()
            loss_centerness = pos_centerness.sum()

        # Cache everything the mask head needs for this iteration; it is
        # consumed by ``get_positive_infos``.
        self._raw_positive_infos.update(cls_scores=cls_scores)
        self._raw_positive_infos.update(centernesses=centernesses)
        self._raw_positive_infos.update(param_preds=param_preds)
        self._raw_positive_infos.update(all_level_points=all_level_points)
        self._raw_positive_infos.update(all_level_strides=all_level_strides)
        self._raw_positive_infos.update(pos_gt_inds_list=pos_gt_inds_list)
        self._raw_positive_infos.update(pos_inds_list=pos_inds_list)

        return dict(
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            loss_centerness=loss_centerness)
|
|
| def get_targets( |
| self, points: List[Tensor], batch_gt_instances: InstanceList |
| ) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]: |
| """Compute regression, classification and centerness targets for points |
| in multiple images. |
| |
| Args: |
| points (list[Tensor]): Points of each fpn level, each has shape |
| (num_points, 2). |
| batch_gt_instances (list[:obj:`InstanceData`]): Batch of |
| gt_instance. It usually includes ``bboxes`` and ``labels`` |
| attributes. |
| |
| Returns: |
| tuple: Targets of each level. |
| |
| - concat_lvl_labels (list[Tensor]): Labels of each level. |
| - concat_lvl_bbox_targets (list[Tensor]): BBox targets of each \ |
| level. |
| - pos_inds_list (list[Tensor]): pos_inds of each image. |
| - pos_gt_inds_list (List[Tensor]): pos_gt_inds of each image. |
| """ |
| assert len(points) == len(self.regress_ranges) |
| num_levels = len(points) |
| |
| expanded_regress_ranges = [ |
| points[i].new_tensor(self.regress_ranges[i])[None].expand_as( |
| points[i]) for i in range(num_levels) |
| ] |
| |
| concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0) |
| concat_points = torch.cat(points, dim=0) |
|
|
| |
| num_points = [center.size(0) for center in points] |
|
|
| |
| labels_list, bbox_targets_list, pos_inds_list, pos_gt_inds_list = \ |
| multi_apply( |
| self._get_targets_single, |
| batch_gt_instances, |
| points=concat_points, |
| regress_ranges=concat_regress_ranges, |
| num_points_per_lvl=num_points) |
|
|
| |
| labels_list = [labels.split(num_points, 0) for labels in labels_list] |
| bbox_targets_list = [ |
| bbox_targets.split(num_points, 0) |
| for bbox_targets in bbox_targets_list |
| ] |
|
|
| |
| concat_lvl_labels = [] |
| concat_lvl_bbox_targets = [] |
| for i in range(num_levels): |
| concat_lvl_labels.append( |
| torch.cat([labels[i] for labels in labels_list])) |
| bbox_targets = torch.cat( |
| [bbox_targets[i] for bbox_targets in bbox_targets_list]) |
| if self.norm_on_bbox: |
| bbox_targets = bbox_targets / self.strides[i] |
| concat_lvl_bbox_targets.append(bbox_targets) |
| return (concat_lvl_labels, concat_lvl_bbox_targets, pos_inds_list, |
| pos_gt_inds_list) |
|
|
    def _get_targets_single(
            self, gt_instances: InstanceData, points: Tensor,
            regress_ranges: Tensor, num_points_per_lvl: List[int]
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Compute regression and classification targets for a single image.

        Args:
            gt_instances (:obj:`InstanceData`): GT instances of the image;
                ``bboxes`` and ``labels`` are required, ``masks`` is
                optional (used for mass-center sampling).
            points (Tensor): All-level points, shape (num_points, 2).
            regress_ranges (Tensor): Per-point regress range,
                shape (num_points, 2).
            num_points_per_lvl (list[int]): Number of points of each level.

        Returns:
            tuple: labels, bbox_targets, pos_inds and pos_gt_inds of
            this image.
        """
        num_points = points.size(0)
        num_gts = len(gt_instances)
        gt_bboxes = gt_instances.bboxes
        gt_labels = gt_instances.labels
        gt_masks = gt_instances.get('masks', None)

        # No GT: every point is background, no positives.
        if num_gts == 0:
            return gt_labels.new_full((num_points,), self.num_classes), \
                   gt_bboxes.new_zeros((num_points, 4)), \
                   gt_bboxes.new_zeros((0,), dtype=torch.int64), \
                   gt_bboxes.new_zeros((0,), dtype=torch.int64)

        areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (
            gt_bboxes[:, 3] - gt_bboxes[:, 1])
        # Expand everything to (num_points, num_gts, ...) so each point is
        # compared against every GT box at once.
        areas = areas[None].repeat(num_points, 1)
        regress_ranges = regress_ranges[:, None, :].expand(
            num_points, num_gts, 2)
        gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
        xs, ys = points[:, 0], points[:, 1]
        xs = xs[:, None].expand(num_points, num_gts)
        ys = ys[:, None].expand(num_points, num_gts)

        # FCOS-style targets: distances from the point to the 4 box edges.
        left = xs - gt_bboxes[..., 0]
        right = gt_bboxes[..., 2] - xs
        top = ys - gt_bboxes[..., 1]
        bottom = gt_bboxes[..., 3] - ys
        bbox_targets = torch.stack((left, top, right, bottom), -1)

        if self.center_sampling:
            # condition1: inside a `center bbox`
            radius = self.center_sample_radius
            # If masks are available, sample around the mask mass center
            # instead of the bbox center.
            if gt_masks is None:
                center_xs = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) / 2
                center_ys = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) / 2
            else:
                h, w = gt_masks.height, gt_masks.width
                masks = gt_masks.to_tensor(
                    dtype=torch.bool, device=gt_bboxes.device)
                yys = torch.arange(
                    0, h, dtype=torch.float32, device=masks.device)
                xxs = torch.arange(
                    0, w, dtype=torch.float32, device=masks.device)
                # Image moments: m00 is the mask area, m10/m01 are the
                # first-order sums, so (m10/m00, m01/m00) is the centroid.
                m00 = masks.sum(dim=-1).sum(dim=-1).clamp(min=1e-6)
                m10 = (masks * xxs).sum(dim=-1).sum(dim=-1)
                m01 = (masks * yys[:, None]).sum(dim=-1).sum(dim=-1)
                center_xs = m10 / m00
                center_ys = m01 / m00

            center_xs = center_xs[None].expand(num_points, num_gts)
            center_ys = center_ys[None].expand(num_points, num_gts)
            center_gts = torch.zeros_like(gt_bboxes)
            stride = center_xs.new_zeros(center_xs.shape)

            # Sampling radius is ``radius`` feature-map cells, i.e.
            # stride * radius in image pixels, per level.
            lvl_begin = 0
            for lvl_idx, num_points_lvl in enumerate(num_points_per_lvl):
                lvl_end = lvl_begin + num_points_lvl
                stride[lvl_begin:lvl_end] = self.strides[lvl_idx] * radius
                lvl_begin = lvl_end

            # Center box, clipped to stay inside the GT box.
            x_mins = center_xs - stride
            y_mins = center_ys - stride
            x_maxs = center_xs + stride
            y_maxs = center_ys + stride
            center_gts[..., 0] = torch.where(x_mins > gt_bboxes[..., 0],
                                             x_mins, gt_bboxes[..., 0])
            center_gts[..., 1] = torch.where(y_mins > gt_bboxes[..., 1],
                                             y_mins, gt_bboxes[..., 1])
            center_gts[..., 2] = torch.where(x_maxs > gt_bboxes[..., 2],
                                             gt_bboxes[..., 2], x_maxs)
            center_gts[..., 3] = torch.where(y_maxs > gt_bboxes[..., 3],
                                             gt_bboxes[..., 3], y_maxs)

            # A point is positive only if it lies inside the center box.
            cb_dist_left = xs - center_gts[..., 0]
            cb_dist_right = center_gts[..., 2] - xs
            cb_dist_top = ys - center_gts[..., 1]
            cb_dist_bottom = center_gts[..., 3] - ys
            center_bbox = torch.stack(
                (cb_dist_left, cb_dist_top, cb_dist_right, cb_dist_bottom), -1)
            inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0
        else:
            # condition1: the point must lie inside the GT bbox itself
            inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0

        # condition2: the max edge distance must fall inside the level's
        # regress range, so each level handles objects of its scale.
        max_regress_distance = bbox_targets.max(-1)[0]
        inside_regress_range = (
            (max_regress_distance >= regress_ranges[..., 0])
            & (max_regress_distance <= regress_ranges[..., 1]))

        # When a point matches several GTs, assign it to the smallest-area
        # one; invalid candidates get INF area so they are never chosen.
        areas[inside_gt_bbox_mask == 0] = INF
        areas[inside_regress_range == 0] = INF
        min_area, min_area_inds = areas.min(dim=1)

        labels = gt_labels[min_area_inds]
        labels[min_area == INF] = self.num_classes  # set as BG
        bbox_targets = bbox_targets[range(num_points), min_area_inds]

        # FG cat_id range: [0, num_classes - 1]; BG cat_id: num_classes.
        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero().reshape(-1)
        pos_gt_inds = min_area_inds[labels < self.num_classes]
        return labels, bbox_targets, pos_inds, pos_gt_inds
|
|
| def get_positive_infos(self) -> InstanceList: |
| """Get positive information from sampling results. |
| |
| Returns: |
| list[:obj:`InstanceData`]: Positive information of each image, |
| usually including positive bboxes, positive labels, positive |
| priors, etc. |
| """ |
| assert len(self._raw_positive_infos) > 0 |
|
|
| pos_gt_inds_list = self._raw_positive_infos['pos_gt_inds_list'] |
| pos_inds_list = self._raw_positive_infos['pos_inds_list'] |
| num_imgs = len(pos_gt_inds_list) |
|
|
| cls_score_list = [] |
| centerness_list = [] |
| param_pred_list = [] |
| point_list = [] |
| stride_list = [] |
| for cls_score_per_lvl, centerness_per_lvl, param_pred_per_lvl,\ |
| point_per_lvl, stride_per_lvl in \ |
| zip(self._raw_positive_infos['cls_scores'], |
| self._raw_positive_infos['centernesses'], |
| self._raw_positive_infos['param_preds'], |
| self._raw_positive_infos['all_level_points'], |
| self._raw_positive_infos['all_level_strides']): |
| cls_score_per_lvl = \ |
| cls_score_per_lvl.permute( |
| 0, 2, 3, 1).reshape(num_imgs, -1, self.num_classes) |
| centerness_per_lvl = \ |
| centerness_per_lvl.permute( |
| 0, 2, 3, 1).reshape(num_imgs, -1, 1) |
| param_pred_per_lvl = \ |
| param_pred_per_lvl.permute( |
| 0, 2, 3, 1).reshape(num_imgs, -1, self.num_params) |
| point_per_lvl = point_per_lvl.unsqueeze(0).repeat(num_imgs, 1, 1) |
| stride_per_lvl = stride_per_lvl.unsqueeze(0).repeat(num_imgs, 1) |
|
|
| cls_score_list.append(cls_score_per_lvl) |
| centerness_list.append(centerness_per_lvl) |
| param_pred_list.append(param_pred_per_lvl) |
| point_list.append(point_per_lvl) |
| stride_list.append(stride_per_lvl) |
| cls_scores = torch.cat(cls_score_list, dim=1) |
| centernesses = torch.cat(centerness_list, dim=1) |
| param_preds = torch.cat(param_pred_list, dim=1) |
| all_points = torch.cat(point_list, dim=1) |
| all_strides = torch.cat(stride_list, dim=1) |
|
|
| positive_infos = [] |
| for i, (pos_gt_inds, |
| pos_inds) in enumerate(zip(pos_gt_inds_list, pos_inds_list)): |
| pos_info = InstanceData() |
| pos_info.points = all_points[i][pos_inds] |
| pos_info.strides = all_strides[i][pos_inds] |
| pos_info.scores = cls_scores[i][pos_inds] |
| pos_info.centernesses = centernesses[i][pos_inds] |
| pos_info.param_preds = param_preds[i][pos_inds] |
| pos_info.pos_assigned_gt_inds = pos_gt_inds |
| pos_info.pos_inds = pos_inds |
| positive_infos.append(pos_info) |
| return positive_infos |
|
|
    def predict_by_feat(self,
                        cls_scores: List[Tensor],
                        bbox_preds: List[Tensor],
                        score_factors: Optional[List[Tensor]] = None,
                        param_preds: Optional[List[Tensor]] = None,
                        batch_img_metas: Optional[List[dict]] = None,
                        cfg: Optional[ConfigDict] = None,
                        rescale: bool = False,
                        with_nms: bool = True) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        bbox results.

        Note: When score_factors is not None, the cls_scores are
        usually multiplied by it then obtain the real score used in NMS,
        such as CenterNess in FCOS, IoU branch in ATSS.

        Args:
            cls_scores (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 4, H, W).
            score_factors (list[Tensor], optional): Score factor for
                all scale level, each is a 4D-tensor, has shape
                (batch_size, num_priors * 1, H, W). Defaults to None.
            param_preds (list[Tensor], optional): Params for all scale
                level, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_params, H, W)
            batch_img_metas (list[dict], Optional): Batch image meta info.
                Defaults to None.
            cfg (ConfigDict, optional): Test / postprocessing
                configuration, if None, test_cfg would be used.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        assert len(cls_scores) == len(bbox_preds)

        if score_factors is None:
            # No auxiliary score branch (e.g. centerness) was predicted.
            with_score_factors = False
        else:
            # A centerness-like factor is combined with the class scores.
            with_score_factors = True
            assert len(cls_scores) == len(score_factors)

        num_levels = len(cls_scores)

        # Priors are generated with per-point strides; the strides are kept
        # so the mask head can build relative coordinate maps later.
        featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
        all_level_points_strides = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=bbox_preds[0].dtype,
            device=bbox_preds[0].device,
            with_stride=True)
        all_level_points = [i[:, :2] for i in all_level_points_strides]
        all_level_strides = [i[:, 2] for i in all_level_points_strides]

        result_list = []

        # Post-process image by image.
        for img_id in range(len(batch_img_metas)):
            img_meta = batch_img_metas[img_id]
            cls_score_list = select_single_mlvl(
                cls_scores, img_id, detach=True)
            bbox_pred_list = select_single_mlvl(
                bbox_preds, img_id, detach=True)
            if with_score_factors:
                score_factor_list = select_single_mlvl(
                    score_factors, img_id, detach=True)
            else:
                score_factor_list = [None for _ in range(num_levels)]
            param_pred_list = select_single_mlvl(
                param_preds, img_id, detach=True)

            results = self._predict_by_feat_single(
                cls_score_list=cls_score_list,
                bbox_pred_list=bbox_pred_list,
                score_factor_list=score_factor_list,
                param_pred_list=param_pred_list,
                mlvl_points=all_level_points,
                mlvl_strides=all_level_strides,
                img_meta=img_meta,
                cfg=cfg,
                rescale=rescale,
                with_nms=with_nms)
            result_list.append(results)
        return result_list
|
|
    def _predict_by_feat_single(self,
                                cls_score_list: List[Tensor],
                                bbox_pred_list: List[Tensor],
                                score_factor_list: List[Tensor],
                                param_pred_list: List[Tensor],
                                mlvl_points: List[Tensor],
                                mlvl_strides: List[Tensor],
                                img_meta: dict,
                                cfg: ConfigDict,
                                rescale: bool = False,
                                with_nms: bool = True) -> InstanceData:
        """Transform a single image's features extracted from the head into
        bbox results.

        Args:
            cls_score_list (list[Tensor]): Box scores from all scale
                levels of a single image, each item has shape
                (num_priors * num_classes, H, W).
            bbox_pred_list (list[Tensor]): Box energies / deltas from
                all scale levels of a single image, each item has shape
                (num_priors * 4, H, W).
            score_factor_list (list[Tensor]): Score factor from all scale
                levels of a single image, each item has shape
                (num_priors * 1, H, W).
            param_pred_list (List[Tensor]): Param predition from all scale
                levels of a single image, each item has shape
                (num_priors * num_params, H, W).
            mlvl_points (list[Tensor]): Each element in the list is
                the priors of a single level in feature pyramid.
                It has shape (num_priors, 2)
            mlvl_strides (List[Tensor]): Each element in the list is
                the stride of a single level in feature pyramid.
                It has shape (num_priors, 1)
            img_meta (dict): Image meta info.
            cfg (mmengine.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            :obj:`InstanceData`: Detection results of each image
            after the post process.
            Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        if score_factor_list[0] is None:
            # No centerness-like score factor was predicted.
            with_score_factors = False
        else:
            # Score factors are combined with class scores downstream.
            with_score_factors = True

        cfg = self.test_cfg if cfg is None else cfg
        # Deep-copy so per-image mutation of cfg cannot leak across calls.
        cfg = copy.deepcopy(cfg)
        img_shape = img_meta['img_shape']
        nms_pre = cfg.get('nms_pre', -1)

        mlvl_bbox_preds = []
        mlvl_param_preds = []
        mlvl_valid_points = []
        mlvl_valid_strides = []
        mlvl_scores = []
        mlvl_labels = []
        if with_score_factors:
            mlvl_score_factors = []
        else:
            mlvl_score_factors = None
        for level_idx, (cls_score, bbox_pred, score_factor,
                        param_pred, points, strides) in \
                enumerate(zip(cls_score_list, bbox_pred_list,
                              score_factor_list, param_pred_list,
                              mlvl_points, mlvl_strides)):

            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]

            # Flatten (C, H, W) maps to (H*W, C) rows aligned with points.
            dim = self.bbox_coder.encode_size
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, dim)
            if with_score_factors:
                score_factor = score_factor.permute(1, 2,
                                                    0).reshape(-1).sigmoid()
            cls_score = cls_score.permute(1, 2,
                                          0).reshape(-1, self.cls_out_channels)
            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                # The last column is treated as the background class and
                # dropped after the softmax.
                scores = cls_score.softmax(-1)[:, :-1]

            param_pred = param_pred.permute(1, 2,
                                            0).reshape(-1, self.num_params)

            # Score thresholding and top-k are applied on the raw class
            # scores, before any score_factor multiplication.
            # NOTE(review): the factor appears to be applied later in the
            # base-class ``_bbox_post_process`` — confirm.
            score_thr = cfg.get('score_thr', 0)

            results = filter_scores_and_topk(
                scores, score_thr, nms_pre,
                dict(
                    bbox_pred=bbox_pred,
                    param_pred=param_pred,
                    points=points,
                    strides=strides))
            scores, labels, keep_idxs, filtered_results = results

            # Keep only the surviving rows of every side output.
            bbox_pred = filtered_results['bbox_pred']
            param_pred = filtered_results['param_pred']
            points = filtered_results['points']
            strides = filtered_results['strides']

            if with_score_factors:
                score_factor = score_factor[keep_idxs]

            mlvl_bbox_preds.append(bbox_pred)
            mlvl_param_preds.append(param_pred)
            mlvl_valid_points.append(points)
            mlvl_valid_strides.append(strides)
            mlvl_scores.append(scores)
            mlvl_labels.append(labels)

            if with_score_factors:
                mlvl_score_factors.append(score_factor)

        # Decode the surviving predictions of all levels at once.
        bbox_pred = torch.cat(mlvl_bbox_preds)
        priors = cat_boxes(mlvl_valid_points)
        bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape)

        # Params/points/strides ride along so the mask head can consume them.
        results = InstanceData()
        results.bboxes = bboxes
        results.scores = torch.cat(mlvl_scores)
        results.labels = torch.cat(mlvl_labels)
        results.param_preds = torch.cat(mlvl_param_preds)
        results.points = torch.cat(mlvl_valid_points)
        results.strides = torch.cat(mlvl_valid_strides)
        if with_score_factors:
            results.score_factors = torch.cat(mlvl_score_factors)

        return self._bbox_post_process(
            results=results,
            cfg=cfg,
            rescale=rescale,
            with_nms=with_nms,
            img_meta=img_meta)
|
|
|
|
class MaskFeatModule(BaseModule):
    """CondInst mask feature map branch used in \
    https://arxiv.org/abs/2003.05664.

    Args:
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels of the mask feature
            map branch.
        start_level (int): The starting feature map level from the feature
            pyramid that will be used to predict the mask feature map.
        end_level (int): The ending feature map level from the feature
            pyramid that will be used to predict the mask feature map.
        out_channels (int): Number of output channels of the mask feature
            map branch. This is the channel count of the mask
            feature map that to be dynamically convolved with the predicted
            kernel.
        mask_stride (int): Downsample factor of the mask feature map output.
            Defaults to 4.
        num_stacked_convs (int): Number of convs in mask feature branch.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels: int,
                 feat_channels: int,
                 start_level: int,
                 end_level: int,
                 out_channels: int,
                 mask_stride: int = 4,
                 num_stacked_convs: int = 4,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: OptConfigType = None,
                 init_cfg: MultiConfig = [
                     dict(type='Normal', layer='Conv2d', std=0.01)
                 ],
                 **kwargs) -> None:
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.start_level = start_level
        self.end_level = end_level
        self.mask_stride = mask_stride
        self.num_stacked_convs = num_stacked_convs
        # The used level range must be a valid, non-empty slice of inputs.
        assert start_level >= 0 and end_level >= start_level
        self.out_channels = out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self._init_layers()
|
|
| def _init_layers(self) -> None: |
| """Initialize layers of the head.""" |
| self.convs_all_levels = nn.ModuleList() |
| for i in range(self.start_level, self.end_level + 1): |
| convs_per_level = nn.Sequential() |
| convs_per_level.add_module( |
| f'conv{i}', |
| ConvModule( |
| self.in_channels, |
| self.feat_channels, |
| 3, |
| padding=1, |
| conv_cfg=self.conv_cfg, |
| norm_cfg=self.norm_cfg, |
| inplace=False, |
| bias=False)) |
| self.convs_all_levels.append(convs_per_level) |
|
|
| conv_branch = [] |
| for _ in range(self.num_stacked_convs): |
| conv_branch.append( |
| ConvModule( |
| self.feat_channels, |
| self.feat_channels, |
| 3, |
| padding=1, |
| conv_cfg=self.conv_cfg, |
| norm_cfg=self.norm_cfg, |
| bias=False)) |
| self.conv_branch = nn.Sequential(*conv_branch) |
|
|
| self.conv_pred = nn.Conv2d( |
| self.feat_channels, self.out_channels, 1, stride=1) |
|
|
    def init_weights(self) -> None:
        """Initialize weights of the head."""
        super().init_weights()
        # Kaiming-uniform init for all three sub-branches; ``a=1`` is the
        # negative-slope argument forwarded to the kaiming initializer.
        kaiming_init(self.convs_all_levels, a=1, distribution='uniform')
        kaiming_init(self.conv_branch, a=1, distribution='uniform')
        kaiming_init(self.conv_pred, a=1, distribution='uniform')
|
|
| def forward(self, x: Tuple[Tensor]) -> Tensor: |
| """Forward features from the upstream network. |
| |
| Args: |
| x (tuple[Tensor]): Features from the upstream network, each is |
| a 4D-tensor. |
| |
| Returns: |
| Tensor: The predicted mask feature map. |
| """ |
| inputs = x[self.start_level:self.end_level + 1] |
| assert len(inputs) == (self.end_level - self.start_level + 1) |
| feature_add_all_level = self.convs_all_levels[0](inputs[0]) |
| target_h, target_w = feature_add_all_level.size()[2:] |
| for i in range(1, len(inputs)): |
| input_p = inputs[i] |
| x_p = self.convs_all_levels[i](input_p) |
| h, w = x_p.size()[2:] |
| factor_h = target_h // h |
| factor_w = target_w // w |
| assert factor_h == factor_w |
| feature_per_level = aligned_bilinear(x_p, factor_h) |
| feature_add_all_level = feature_add_all_level + \ |
| feature_per_level |
|
|
| feature_add_all_level = self.conv_branch(feature_add_all_level) |
| feature_pred = self.conv_pred(feature_add_all_level) |
| return feature_pred |
|
|
|
|
@MODELS.register_module()
class CondInstMaskHead(BaseMaskHead):
    """CondInst mask head used in https://arxiv.org/abs/2003.05664.

    This head outputs the mask for CondInst.

    Args:
        mask_feature_head (dict): Config of CondInstMaskFeatHead.
        num_layers (int): Number of dynamic conv layers.
        feat_channels (int): Number of channels in the dynamic conv.
        mask_out_stride (int): The stride of the mask feat.
        size_of_interest (int): The size of the region used in rel coord.
        max_masks_to_train (int): Maximum number of masks to train for
            each image. -1 means no limit.
        topk_masks_per_img (int): Number of top-scoring masks kept per
            image. -1 means keep all.
        loss_mask (:obj:`ConfigDict` or dict): Config of mask loss.
            Required in practice: the default ``None`` would fail in
            ``MODELS.build``.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config
            of head.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            head.
    """

    def __init__(self,
                 mask_feature_head: ConfigType,
                 num_layers: int = 3,
                 feat_channels: int = 8,
                 mask_out_stride: int = 4,
                 size_of_interest: int = 8,
                 max_masks_to_train: int = -1,
                 topk_masks_per_img: int = -1,
                 loss_mask: ConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None) -> None:
        super().__init__()
        self.mask_feature_head = MaskFeatModule(**mask_feature_head)
        # The mask branch's output stride determines the prior grid used
        # for relative coordinates.
        self.mask_feat_stride = self.mask_feature_head.mask_stride
        self.in_channels = self.mask_feature_head.out_channels
        self.num_layers = num_layers
        self.feat_channels = feat_channels
        self.size_of_interest = size_of_interest
        self.mask_out_stride = mask_out_stride
        self.max_masks_to_train = max_masks_to_train
        self.topk_masks_per_img = topk_masks_per_img
        # Single-level point generator at the mask feature stride.
        self.prior_generator = MlvlPointGenerator([self.mask_feat_stride])

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.loss_mask = MODELS.build(loss_mask)
        self._init_layers()
|
|
| def _init_layers(self) -> None: |
| """Initialize layers of the head.""" |
| weight_nums, bias_nums = [], [] |
| for i in range(self.num_layers): |
| if i == 0: |
| weight_nums.append((self.in_channels + 2) * self.feat_channels) |
| bias_nums.append(self.feat_channels) |
| elif i == self.num_layers - 1: |
| weight_nums.append(self.feat_channels * 1) |
| bias_nums.append(1) |
| else: |
| weight_nums.append(self.feat_channels * self.feat_channels) |
| bias_nums.append(self.feat_channels) |
|
|
| self.weight_nums = weight_nums |
| self.bias_nums = bias_nums |
| self.num_params = sum(weight_nums) + sum(bias_nums) |
|
|
| def parse_dynamic_params( |
| self, params: Tensor) -> Tuple[List[Tensor], List[Tensor]]: |
| """parse the dynamic params for dynamic conv.""" |
| num_insts = params.size(0) |
| params_splits = list( |
| torch.split_with_sizes( |
| params, self.weight_nums + self.bias_nums, dim=1)) |
| weight_splits = params_splits[:self.num_layers] |
| bias_splits = params_splits[self.num_layers:] |
| for i in range(self.num_layers): |
| if i < self.num_layers - 1: |
| weight_splits[i] = weight_splits[i].reshape( |
| num_insts * self.in_channels, -1, 1, 1) |
| bias_splits[i] = bias_splits[i].reshape(num_insts * |
| self.in_channels) |
| else: |
| |
| weight_splits[i] = weight_splits[i].reshape( |
| num_insts * 1, -1, 1, 1) |
| bias_splits[i] = bias_splits[i].reshape(num_insts) |
|
|
| return weight_splits, bias_splits |
|
|
| def dynamic_conv_forward(self, features: Tensor, weights: List[Tensor], |
| biases: List[Tensor], num_insts: int) -> Tensor: |
| """dynamic forward, each layer follow a relu.""" |
| n_layers = len(weights) |
| x = features |
| for i, (w, b) in enumerate(zip(weights, biases)): |
| x = F.conv2d(x, w, bias=b, stride=1, padding=0, groups=num_insts) |
| if i < n_layers - 1: |
| x = F.relu(x) |
| return x |
|
|
| def forward(self, x: tuple, positive_infos: InstanceList) -> tuple: |
| """Forward feature from the upstream network to get prototypes and |
| linearly combine the prototypes, using masks coefficients, into |
| instance masks. Finally, crop the instance masks with given bboxes. |
| |
| Args: |
| x (Tuple[Tensor]): Feature from the upstream network, which is |
| a 4D-tensor. |
| positive_infos (List[:obj:``InstanceData``]): Positive information |
| that calculate from detect head. |
| |
| Returns: |
| tuple: Predicted instance segmentation masks |
| """ |
| mask_feats = self.mask_feature_head(x) |
| return multi_apply(self.forward_single, mask_feats, positive_infos) |
|
|
| def forward_single(self, mask_feat: Tensor, |
| positive_info: InstanceData) -> Tensor: |
| """Forward features of a each image.""" |
| pos_param_preds = positive_info.get('param_preds') |
| pos_points = positive_info.get('points') |
| pos_strides = positive_info.get('strides') |
|
|
| num_inst = pos_param_preds.shape[0] |
| mask_feat = mask_feat[None].repeat(num_inst, 1, 1, 1) |
| _, _, H, W = mask_feat.size() |
| if num_inst == 0: |
| return (pos_param_preds.new_zeros((0, 1, H, W)), ) |
|
|
| locations = self.prior_generator.single_level_grid_priors( |
| mask_feat.size()[2:], 0, device=mask_feat.device) |
|
|
| rel_coords = relative_coordinate_maps(locations, pos_points, |
| pos_strides, |
| self.size_of_interest, |
| mask_feat.size()[2:]) |
| mask_head_inputs = torch.cat([rel_coords, mask_feat], dim=1) |
| mask_head_inputs = mask_head_inputs.reshape(1, -1, H, W) |
|
|
| weights, biases = self.parse_dynamic_params(pos_param_preds) |
| mask_preds = self.dynamic_conv_forward(mask_head_inputs, weights, |
| biases, num_inst) |
| mask_preds = mask_preds.reshape(-1, H, W) |
| mask_preds = aligned_bilinear( |
| mask_preds.unsqueeze(0), |
| int(self.mask_feat_stride / self.mask_out_stride)).squeeze(0) |
|
|
| return (mask_preds, ) |
|
|
| def loss_by_feat(self, mask_preds: List[Tensor], |
| batch_gt_instances: InstanceList, |
| batch_img_metas: List[dict], positive_infos: InstanceList, |
| **kwargs) -> dict: |
| """Calculate the loss based on the features extracted by the mask head. |
| |
| Args: |
| mask_preds (list[Tensor]): List of predicted masks, each has |
| shape (num_classes, H, W). |
| batch_gt_instances (list[:obj:`InstanceData`]): Batch of |
| gt_instance. It usually includes ``bboxes``, ``masks``, |
| and ``labels`` attributes. |
| batch_img_metas (list[dict]): Meta information of multiple images. |
| positive_infos (List[:obj:``InstanceData``]): Information of |
| positive samples of each image that are assigned in detection |
| head. |
| |
| Returns: |
| dict[str, Tensor]: A dictionary of loss components. |
| """ |
| assert positive_infos is not None, \ |
| 'positive_infos should not be None in `CondInstMaskHead`' |
| losses = dict() |
|
|
| loss_mask = 0. |
| num_imgs = len(mask_preds) |
| total_pos = 0 |
|
|
| for idx in range(num_imgs): |
| (mask_pred, pos_mask_targets, num_pos) = \ |
| self._get_targets_single( |
| mask_preds[idx], batch_gt_instances[idx], |
| positive_infos[idx]) |
| |
| total_pos += num_pos |
| if num_pos == 0 or pos_mask_targets is None: |
| loss = mask_pred.new_zeros(1).mean() |
| else: |
| loss = self.loss_mask( |
| mask_pred, pos_mask_targets, |
| reduction_override='none').sum() |
| loss_mask += loss |
|
|
| if total_pos == 0: |
| total_pos += 1 |
| loss_mask = loss_mask / total_pos |
| losses.update(loss_mask=loss_mask) |
| return losses |
|
|
    def _get_targets_single(self, mask_preds: Tensor,
                            gt_instances: InstanceData,
                            positive_info: InstanceData) -> tuple:
        """Compute mask targets for the predictions of a single image.

        May subsample the positives: when ``self.max_masks_to_train`` is
        set, a random subset of that size is kept; otherwise, when
        ``self.topk_masks_per_img`` is set, only the top-scoring positives
        per GT instance are kept (ranked by ``scores`` * ``centernesses``).

        Args:
            mask_preds (Tensor): Predicted masks of this image with shape
                (num_pos, mask_h, mask_w).
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should includes ``bboxes``, ``labels``,
                and ``masks`` attributes.
            positive_info (:obj:`InstanceData`): Information of positive
                samples that are assigned in detection head. The keys read
                here are:

                - pos_assigned_gt_inds (Tensor): Assigned GT indexes of
                  positive proposals, has shape (num_pos, ).
                - scores (Tensor): Classification logits of the positives;
                  used (after sigmoid) only for top-k selection.
                - centernesses (Tensor): Centerness logits of the
                  positives; used (after sigmoid) only for top-k selection.

        Returns:
            tuple: Usually returns a tuple containing learning targets.

            - mask_preds (Tensor): Positive predicted masks (possibly
              subsampled) with shape (num_pos, mask_h, mask_w).
            - pos_mask_targets (Tensor | None): Positive mask targets with
              shape (num_pos, mask_h, mask_w); None when there are no
              GT masks or no positives.
            - num_pos (int): Number of positives actually kept.
        """
        gt_bboxes = gt_instances.bboxes
        device = gt_bboxes.device
        # Rasterize GT masks to a float tensor on the prediction's device.
        gt_masks = gt_instances.masks.to_tensor(
            dtype=torch.bool, device=device).float()

        # Positive-sample bookkeeping produced by the detection head.
        pos_assigned_gt_inds = positive_info.get('pos_assigned_gt_inds')
        scores = positive_info.get('scores')
        centernesses = positive_info.get('centernesses')
        num_pos = pos_assigned_gt_inds.size(0)

        if gt_masks.size(0) == 0 or num_pos == 0:
            return mask_preds, None, 0
        # Randomly cap the number of trained masks to bound memory use.
        if (self.max_masks_to_train != -1) and \
           (num_pos > self.max_masks_to_train):
            perm = torch.randperm(num_pos)
            select = perm[:self.max_masks_to_train]
            mask_preds = mask_preds[select]
            pos_assigned_gt_inds = pos_assigned_gt_inds[select]
            num_pos = self.max_masks_to_train
        elif self.topk_masks_per_img != -1:
            # Keep a roughly even budget of positives per GT instance.
            unique_gt_inds = pos_assigned_gt_inds.unique()
            num_inst_per_gt = max(
                int(self.topk_masks_per_img / len(unique_gt_inds)), 1)

            keep_mask_preds = []
            keep_pos_assigned_gt_inds = []
            for gt_ind in unique_gt_inds:
                per_inst_pos_inds = (pos_assigned_gt_inds == gt_ind)
                mask_preds_per_inst = mask_preds[per_inst_pos_inds]
                gt_inds_per_inst = pos_assigned_gt_inds[per_inst_pos_inds]
                if sum(per_inst_pos_inds) > num_inst_per_gt:
                    # Rank by joint cls-score * centerness (both sigmoid-ed)
                    # and keep only the top candidates for this GT.
                    per_inst_scores = scores[per_inst_pos_inds].sigmoid().max(
                        dim=1)[0]
                    per_inst_centerness = centernesses[
                        per_inst_pos_inds].sigmoid().reshape(-1, )
                    select = (per_inst_scores * per_inst_centerness).topk(
                        k=num_inst_per_gt, dim=0)[1]
                    mask_preds_per_inst = mask_preds_per_inst[select]
                    gt_inds_per_inst = gt_inds_per_inst[select]
                keep_mask_preds.append(mask_preds_per_inst)
                keep_pos_assigned_gt_inds.append(gt_inds_per_inst)
            mask_preds = torch.cat(keep_mask_preds)
            pos_assigned_gt_inds = torch.cat(keep_pos_assigned_gt_inds)
            num_pos = pos_assigned_gt_inds.size(0)

        # Subsample GT masks to the mask output stride, sampling at the
        # stride centers so targets align with the upsampled predictions.
        start = int(self.mask_out_stride // 2)
        gt_masks = gt_masks[:, start::self.mask_out_stride,
                            start::self.mask_out_stride]
        gt_masks = gt_masks.gt(0.5).float()
        pos_mask_targets = gt_masks[pos_assigned_gt_inds]

        return (mask_preds, pos_mask_targets, num_pos)
|
|
| def predict_by_feat(self, |
| mask_preds: List[Tensor], |
| results_list: InstanceList, |
| batch_img_metas: List[dict], |
| rescale: bool = True, |
| **kwargs) -> InstanceList: |
| """Transform a batch of output features extracted from the head into |
| mask results. |
| |
| Args: |
| mask_preds (list[Tensor]): Predicted prototypes with shape |
| (num_classes, H, W). |
| results_list (List[:obj:``InstanceData``]): BBoxHead results. |
| batch_img_metas (list[dict]): Meta information of all images. |
| rescale (bool, optional): Whether to rescale the results. |
| Defaults to False. |
| |
| Returns: |
| list[:obj:`InstanceData`]: Processed results of multiple |
| images.Each :obj:`InstanceData` usually contains |
| following keys. |
| |
| - scores (Tensor): Classification scores, has shape |
| (num_instance,). |
| - labels (Tensor): Has shape (num_instances,). |
| - masks (Tensor): Processed mask results, has |
| shape (num_instances, h, w). |
| """ |
| assert len(mask_preds) == len(results_list) == len(batch_img_metas) |
|
|
| for img_id in range(len(batch_img_metas)): |
| img_meta = batch_img_metas[img_id] |
| results = results_list[img_id] |
| bboxes = results.bboxes |
| mask_pred = mask_preds[img_id] |
| if bboxes.shape[0] == 0 or mask_pred.shape[0] == 0: |
| results_list[img_id] = empty_instances( |
| [img_meta], |
| bboxes.device, |
| task_type='mask', |
| instance_results=[results])[0] |
| else: |
| im_mask = self._predict_by_feat_single( |
| mask_preds=mask_pred, |
| bboxes=bboxes, |
| img_meta=img_meta, |
| rescale=rescale) |
| results.masks = im_mask |
| return results_list |
|
|
| def _predict_by_feat_single(self, |
| mask_preds: Tensor, |
| bboxes: Tensor, |
| img_meta: dict, |
| rescale: bool, |
| cfg: OptConfigType = None): |
| """Transform a single image's features extracted from the head into |
| mask results. |
| |
| Args: |
| mask_preds (Tensor): Predicted prototypes, has shape [H, W, N]. |
| img_meta (dict): Meta information of each image, e.g., |
| image size, scaling factor, etc. |
| rescale (bool): If rescale is False, then returned masks will |
| fit the scale of imgs[0]. |
| cfg (dict, optional): Config used in test phase. |
| Defaults to None. |
| |
| Returns: |
| :obj:`InstanceData`: Processed results of single image. |
| it usually contains following keys. |
| |
| - scores (Tensor): Classification scores, has shape |
| (num_instance,). |
| - labels (Tensor): Has shape (num_instances,). |
| - masks (Tensor): Processed mask results, has |
| shape (num_instances, h, w). |
| """ |
| cfg = self.test_cfg if cfg is None else cfg |
| scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat( |
| (1, 2)) |
| img_h, img_w = img_meta['img_shape'][:2] |
| ori_h, ori_w = img_meta['ori_shape'][:2] |
|
|
| mask_preds = mask_preds.sigmoid().unsqueeze(0) |
| mask_preds = aligned_bilinear(mask_preds, self.mask_out_stride) |
| mask_preds = mask_preds[:, :, :img_h, :img_w] |
| if rescale: |
| scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat( |
| (1, 2)) |
| bboxes /= scale_factor |
|
|
| masks = F.interpolate( |
| mask_preds, (ori_h, ori_w), |
| mode='bilinear', |
| align_corners=False).squeeze(0) > cfg.mask_thr |
| else: |
| masks = mask_preds.squeeze(0) > cfg.mask_thr |
|
|
| return masks |
|
|