import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, xavier_init
from mmcv.runner import force_fp32

from mmdet.core import build_sampler, fast_nms, images_to_levels, multi_apply
from ..builder import HEADS, build_loss
from .anchor_head import AnchorHead


@HEADS.register_module()
class YOLACTHead(AnchorHead):
| """YOLACT box head used in https://arxiv.org/abs/1904.02689. |
| |
| Note that YOLACT head is a light version of RetinaNet head. |
| Four differences are described as follows: |
| |
| 1. YOLACT box head has three-times fewer anchors. |
| 2. YOLACT box head shares the convs for box and cls branches. |
| 3. YOLACT box head uses OHEM instead of Focal loss. |
| 4. YOLACT box head predicts a set of mask coefficients for each box. |
| |
| Args: |
| num_classes (int): Number of categories excluding the background |
| category. |
| in_channels (int): Number of channels in the input feature map. |
| anchor_generator (dict): Config dict for anchor generator |
| loss_cls (dict): Config of classification loss. |
| loss_bbox (dict): Config of localization loss. |
| num_head_convs (int): Number of the conv layers shared by |
| box and cls branches. |
| num_protos (int): Number of the mask coefficients. |
| use_ohem (bool): If true, ``loss_single_OHEM`` will be used for |
| cls loss calculation. If false, ``loss_single`` will be used. |
| conv_cfg (dict): Dictionary to construct and config conv layer. |
| norm_cfg (dict): Dictionary to construct and config norm layer. |
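
    Example:
        A minimal, illustrative usage sketch (the arguments below are just
        the defaults above, not a tuned configuration)::

            >>> self = YOLACTHead(num_classes=80, in_channels=256)
            >>> feat = torch.rand(1, 256, 32, 32)
            >>> cls_score, bbox_pred, coeff_pred = self.forward_single(feat)
            >>> # every anchor predicts num_protos mask coefficients
            >>> assert coeff_pred.size(1) == self.num_anchors * self.num_protos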
| """ |

    def __init__(self,
                 num_classes,
                 in_channels,
                 anchor_generator=dict(
                     type='AnchorGenerator',
                     octave_base_scale=3,
                     scales_per_octave=1,
                     ratios=[0.5, 1.0, 2.0],
                     strides=[8, 16, 32, 64, 128]),
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     reduction='none',
                     loss_weight=1.0),
                 loss_bbox=dict(
                     type='SmoothL1Loss', beta=1.0, loss_weight=1.5),
                 num_head_convs=1,
                 num_protos=32,
                 use_ohem=True,
                 conv_cfg=None,
                 norm_cfg=None,
                 **kwargs):
        self.num_head_convs = num_head_convs
        self.num_protos = num_protos
        self.use_ohem = use_ohem
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        super(YOLACTHead, self).__init__(
            num_classes,
            in_channels,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            anchor_generator=anchor_generator,
            **kwargs)
        if self.use_ohem:
            sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)
            self.sampling = False

    def _init_layers(self):
        """Initialize layers of the head."""
        self.relu = nn.ReLU(inplace=True)
        self.head_convs = nn.ModuleList()
        for i in range(self.num_head_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.head_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
        self.conv_cls = nn.Conv2d(
            self.feat_channels,
            self.num_anchors * self.cls_out_channels,
            3,
            padding=1)
        self.conv_reg = nn.Conv2d(
            self.feat_channels, self.num_anchors * 4, 3, padding=1)
        self.conv_coeff = nn.Conv2d(
            self.feat_channels,
            self.num_anchors * self.num_protos,
            3,
            padding=1)

    def init_weights(self):
        """Initialize weights of the head."""
        for m in self.head_convs:
            xavier_init(m.conv, distribution='uniform', bias=0)
        xavier_init(self.conv_cls, distribution='uniform', bias=0)
        xavier_init(self.conv_reg, distribution='uniform', bias=0)
        xavier_init(self.conv_coeff, distribution='uniform', bias=0)

    def forward_single(self, x):
        """Forward feature of a single scale level.

        Args:
            x (Tensor): Features of a single scale level.

        Returns:
            tuple:
                cls_score (Tensor): Cls scores for a single scale level, \
                    the channel number is num_anchors * num_classes.
                bbox_pred (Tensor): Box energies / deltas for a single scale \
                    level, the channel number is num_anchors * 4.
                coeff_pred (Tensor): Mask coefficients for a single scale \
                    level, the channel number is num_anchors * num_protos.
        """
        for head_conv in self.head_convs:
            x = head_conv(x)
        cls_score = self.conv_cls(x)
        bbox_pred = self.conv_reg(x)
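        # tanh keeps each coefficient in (-1, 1), so prototypes can be both
        # added and subtracted when instance masks are assembled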
        coeff_pred = self.conv_coeff(x).tanh()
        return cls_score, bbox_pred, coeff_pred

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """A combination of the func:``AnchorHead.loss`` and
        func:``SSDHead.loss``.

        When ``self.use_ohem == True``, it functions like ``SSDHead.loss``;
        otherwise, it follows ``AnchorHead.loss``. Besides, it additionally
        returns ``sampling_results``.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                with shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W).
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss. Default: None.

        Returns:
            tuple:
                dict[str, Tensor]: A dictionary of loss components.
                List[:obj:``SamplingResult``]: Sampler results for each image.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.anchor_generator.num_levels

        device = cls_scores[0].device

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=device)
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            unmap_outputs=not self.use_ohem,
            return_sampling_results=True)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg, sampling_results) = cls_reg_targets

        if self.use_ohem:
            num_images = len(img_metas)
            all_cls_scores = torch.cat([
                s.permute(0, 2, 3, 1).reshape(
                    num_images, -1, self.cls_out_channels) for s in cls_scores
            ], 1)
            all_labels = torch.cat(labels_list, -1).view(num_images, -1)
            all_label_weights = torch.cat(label_weights_list,
                                          -1).view(num_images, -1)
            all_bbox_preds = torch.cat([
                b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
                for b in bbox_preds
            ], -2)
            all_bbox_targets = torch.cat(bbox_targets_list,
                                         -2).view(num_images, -1, 4)
            all_bbox_weights = torch.cat(bbox_weights_list,
                                         -2).view(num_images, -1, 4)

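            # concat all level anchors to a single tensor per image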
            all_anchors = []
            for i in range(num_images):
                all_anchors.append(torch.cat(anchor_list[i]))

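            # check NaN and Inf before OHEM selects hard examples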
            assert torch.isfinite(all_cls_scores).all().item(), \
                'classification scores become infinite or NaN!'
            assert torch.isfinite(all_bbox_preds).all().item(), \
                'bbox predictions become infinite or NaN!'

            losses_cls, losses_bbox = multi_apply(
                self.loss_single_OHEM,
                all_cls_scores,
                all_bbox_preds,
                all_anchors,
                all_labels,
                all_label_weights,
                all_bbox_targets,
                all_bbox_weights,
                num_total_samples=num_total_pos)
        else:
            num_total_samples = (
                num_total_pos +
                num_total_neg if self.sampling else num_total_pos)

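            # number of anchors at each level (same for every image)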
            num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
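            # concat anchors of all levels per image, then regroup them
            # level-by-level for loss_single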
            concat_anchor_list = []
            for i in range(len(anchor_list)):
                concat_anchor_list.append(torch.cat(anchor_list[i]))
            all_anchor_list = images_to_levels(concat_anchor_list,
                                               num_level_anchors)
            losses_cls, losses_bbox = multi_apply(
                self.loss_single,
                cls_scores,
                bbox_preds,
                all_anchor_list,
                labels_list,
                label_weights_list,
                bbox_targets_list,
                bbox_weights_list,
                num_total_samples=num_total_samples)

        return dict(
            loss_cls=losses_cls, loss_bbox=losses_bbox), sampling_results

    def loss_single_OHEM(self, cls_score, bbox_pred, anchors, labels,
                         label_weights, bbox_targets, bbox_weights,
                         num_total_samples):
        """See func:``SSDHead.loss``."""
        loss_cls_all = self.loss_cls(cls_score, labels, label_weights)

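        # FG cat_id: [0, num_classes - 1]; BG cat_id: num_classes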
        pos_inds = ((labels >= 0) & (labels < self.num_classes)).nonzero(
            as_tuple=False).reshape(-1)
        neg_inds = (labels == self.num_classes).nonzero(
            as_tuple=False).view(-1)

        num_pos_samples = pos_inds.size(0)
        if num_pos_samples == 0:
            num_neg_samples = neg_inds.size(0)
        else:
            num_neg_samples = self.train_cfg.neg_pos_ratio * num_pos_samples
            if num_neg_samples > neg_inds.size(0):
                num_neg_samples = neg_inds.size(0)
        topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
        loss_cls_pos = loss_cls_all[pos_inds].sum()
        loss_cls_neg = topk_loss_cls_neg.sum()
        loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples
        if self.reg_decoded_bbox:
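            # regression losses such as IoULoss act on decoded (absolute)
            # boxes, so turn the predicted deltas back into boxes first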
            bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
        loss_bbox = self.loss_bbox(
            bbox_pred,
            bbox_targets,
            bbox_weights,
            avg_factor=num_total_samples)
        return loss_cls[None], loss_bbox

    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'coeff_preds'))
    def get_bboxes(self,
                   cls_scores,
                   bbox_preds,
                   coeff_preds,
                   img_metas,
                   cfg=None,
                   rescale=False):
        """Similar to func:``AnchorHead.get_bboxes``, but additionally
        processes coeff_preds.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                with shape (N, num_anchors * num_classes, H, W)
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W)
            coeff_preds (list[Tensor]): Mask coefficients for each scale
                level with shape (N, num_anchors * num_protos, H, W)
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            cfg (mmcv.Config | None): Test / postprocessing configuration,
                if None, test_cfg would be used
            rescale (bool): If True, return boxes in original image space.
                Default: False.

        Returns:
            tuple[list[Tensor], list[Tensor], list[Tensor]]: Detection
                results of each image. Each item of the first list is an
                (n, 5) tensor, where the first 4 columns are bounding box
                positions (tl_x, tl_y, br_x, br_y) and the 5-th column is a
                score between 0 and 1. Each item of the second list is an
                (n,) tensor holding the predicted class label of the
                corresponding box. Each item of the third list is an
                (n, num_protos) tensor holding the predicted mask
                coefficients of the instance inside the corresponding box.
        """
        assert len(cls_scores) == len(bbox_preds)
        num_levels = len(cls_scores)

        device = cls_scores[0].device
        featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
        mlvl_anchors = self.anchor_generator.grid_anchors(
            featmap_sizes, device=device)

        det_bboxes = []
        det_labels = []
        det_coeffs = []
        for img_id in range(len(img_metas)):
            cls_score_list = [
                cls_scores[i][img_id].detach() for i in range(num_levels)
            ]
            bbox_pred_list = [
                bbox_preds[i][img_id].detach() for i in range(num_levels)
            ]
            coeff_pred_list = [
                coeff_preds[i][img_id].detach() for i in range(num_levels)
            ]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            bbox_res = self._get_bboxes_single(cls_score_list, bbox_pred_list,
                                               coeff_pred_list, mlvl_anchors,
                                               img_shape, scale_factor, cfg,
                                               rescale)
            det_bboxes.append(bbox_res[0])
            det_labels.append(bbox_res[1])
            det_coeffs.append(bbox_res[2])
        return det_bboxes, det_labels, det_coeffs

    def _get_bboxes_single(self,
                           cls_score_list,
                           bbox_pred_list,
                           coeff_preds_list,
                           mlvl_anchors,
                           img_shape,
                           scale_factor,
                           cfg,
                           rescale=False):
        """Similar to func:``AnchorHead._get_bboxes_single``, but
        additionally processes coeff_preds_list and uses fast NMS instead of
        traditional NMS.

        Args:
            cls_score_list (list[Tensor]): Box scores of all scale levels
                for a single image, each with shape
                (num_anchors * num_classes, H, W).
            bbox_pred_list (list[Tensor]): Box energies / deltas of all
                scale levels for a single image, each with shape
                (num_anchors * 4, H, W).
            coeff_preds_list (list[Tensor]): Mask coefficients of all scale
                levels for a single image, each with shape
                (num_anchors * num_protos, H, W).
            mlvl_anchors (list[Tensor]): Anchors of all scale levels, each
                with shape (num_level_anchors, 4).
            img_shape (tuple[int]): Shape of the input image,
                (height, width, 3).
            scale_factor (ndarray): Scale factor of the image arranged as
                (w_scale, h_scale, w_scale, h_scale).
            cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.

        Returns:
            tuple[Tensor, Tensor, Tensor]: The first item is an (n, 5)
                tensor, where the first 4 columns are bounding box positions
                (tl_x, tl_y, br_x, br_y) and the 5-th column is a score
                between 0 and 1. The second item is an (n,) tensor where each
                item is the predicted class label of the corresponding box.
                The third item is an (n, num_protos) tensor where each item
                is the predicted mask coefficients of the instance inside the
                corresponding box.
        """
        cfg = self.test_cfg if cfg is None else cfg
        assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)
        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_coeffs = []
        for cls_score, bbox_pred, coeff_pred, anchors in \
                zip(cls_score_list, bbox_pred_list,
                    coeff_preds_list, mlvl_anchors):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            cls_score = cls_score.permute(1, 2,
                                          0).reshape(-1, self.cls_out_channels)
            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                scores = cls_score.softmax(-1)
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            coeff_pred = coeff_pred.permute(1, 2,
                                            0).reshape(-1, self.num_protos)
            nms_pre = cfg.get('nms_pre', -1)
            if nms_pre > 0 and scores.shape[0] > nms_pre:
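                # rank anchors by their best foreground score and keep only
                # the top nms_pre of them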
                if self.use_sigmoid_cls:
                    max_scores, _ = scores.max(dim=1)
                else:
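                    # with softmax the last column is the background class,
                    # so exclude it when taking the per-anchor maximum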
                    max_scores, _ = scores[:, :-1].max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                anchors = anchors[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
                coeff_pred = coeff_pred[topk_inds, :]
            bboxes = self.bbox_coder.decode(
                anchors, bbox_pred, max_shape=img_shape)
            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_coeffs.append(coeff_pred)
        mlvl_bboxes = torch.cat(mlvl_bboxes)
        if rescale:
            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
        mlvl_scores = torch.cat(mlvl_scores)
        mlvl_coeffs = torch.cat(mlvl_coeffs)
        if self.use_sigmoid_cls:
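            # sigmoid outputs carry no background column; pad a dummy zero
            # column so the layout matches the softmax case downstream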
            padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
            mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
        det_bboxes, det_labels, det_coeffs = fast_nms(mlvl_bboxes, mlvl_scores,
                                                      mlvl_coeffs,
                                                      cfg.score_thr,
                                                      cfg.iou_thr, cfg.top_k,
                                                      cfg.max_per_img)
        return det_bboxes, det_labels, det_coeffs


@HEADS.register_module()
class YOLACTSegmHead(nn.Module):
    """YOLACT segmentation head used in https://arxiv.org/abs/1904.02689.

    Apply a semantic segmentation loss on feature space using layers that are
    only evaluated during training to increase performance with no speed
    penalty.

    Args:
        in_channels (int): Number of channels in the input feature map.
        num_classes (int): Number of categories excluding the background
            category.
        loss_segm (dict): Config of semantic segmentation loss.
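
    Example:
        A minimal smoke test (shapes are illustrative, not taken from a
        real config)::

            >>> self = YOLACTSegmHead(num_classes=80, in_channels=256)
            >>> feat = torch.rand(1, 256, 69, 69)
            >>> segm_pred = self(feat)
            >>> # one score map per category at the input resolution
            >>> assert segm_pred.shape == (1, 80, 69, 69)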
| """ |

    def __init__(self,
                 num_classes,
                 in_channels=256,
                 loss_segm=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0)):
        super(YOLACTSegmHead, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.loss_segm = build_loss(loss_segm)
        self._init_layers()
        self.fp16_enabled = False

    def _init_layers(self):
        """Initialize layers of the head."""
        self.segm_conv = nn.Conv2d(
            self.in_channels, self.num_classes, kernel_size=1)

    def init_weights(self):
        """Initialize weights of the head."""
        xavier_init(self.segm_conv, distribution='uniform')

    def forward(self, x):
        """Forward feature from the upstream network.

        Args:
            x (Tensor): Feature from the upstream network, which is
                a 4D-tensor.

        Returns:
            Tensor: Predicted semantic segmentation map with shape
                (N, num_classes, H, W).
        """
        return self.segm_conv(x)

    @force_fp32(apply_to=('segm_pred', ))
    def loss(self, segm_pred, gt_masks, gt_labels):
        """Compute loss of the head.

        Args:
            segm_pred (Tensor): Predicted semantic segmentation map
                with shape (N, num_classes, H, W).
            gt_masks (list[Tensor]): Ground truth masks for each image with
                the same shape as the input image.
            gt_labels (list[Tensor]): Class indices corresponding to each box.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        loss_segm = []
        num_imgs, num_classes, mask_h, mask_w = segm_pred.size()
        for idx in range(num_imgs):
            cur_segm_pred = segm_pred[idx]
            cur_gt_masks = gt_masks[idx].float()
            cur_gt_labels = gt_labels[idx]
            segm_targets = self.get_targets(cur_segm_pred, cur_gt_masks,
                                            cur_gt_labels)
            if segm_targets is None:
                loss = self.loss_segm(cur_segm_pred,
                                      torch.zeros_like(cur_segm_pred),
                                      torch.zeros_like(cur_segm_pred))
            else:
                loss = self.loss_segm(
                    cur_segm_pred,
                    segm_targets,
                    avg_factor=num_imgs * mask_h * mask_w)
            loss_segm.append(loss)
        return dict(loss_segm=loss_segm)

    def get_targets(self, segm_pred, gt_masks, gt_labels):
        """Compute semantic segmentation targets for each image.

        Args:
            segm_pred (Tensor): Predicted semantic segmentation map
                with shape (num_classes, H, W).
            gt_masks (Tensor): Ground truth masks for each image with
                the same shape as the input image.
            gt_labels (Tensor): Class indices corresponding to each box.

        Returns:
            Tensor: Semantic segmentation targets with shape
                (num_classes, H, W).
        """
        if gt_masks.size(0) == 0:
            return None
        num_classes, mask_h, mask_w = segm_pred.size()
        with torch.no_grad():
            downsampled_masks = F.interpolate(
                gt_masks.unsqueeze(0), (mask_h, mask_w),
                mode='bilinear',
                align_corners=False).squeeze(0)
            downsampled_masks = downsampled_masks.gt(0.5).float()
            segm_targets = torch.zeros_like(segm_pred, requires_grad=False)
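            # paint each downsampled instance mask into its class channel,
            # taking the element-wise max where instances overlap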
            for obj_idx in range(downsampled_masks.size(0)):
                segm_targets[gt_labels[obj_idx] - 1] = torch.max(
                    segm_targets[gt_labels[obj_idx] - 1],
                    downsampled_masks[obj_idx])
            return segm_targets


@HEADS.register_module()
class YOLACTProtonet(nn.Module):
    """YOLACT mask head used in https://arxiv.org/abs/1904.02689.

    This head outputs the mask prototypes for YOLACT.

    Args:
        in_channels (int): Number of channels in the input feature map.
        proto_channels (tuple[int]): Output channels of protonet convs.
        proto_kernel_sizes (tuple[int]): Kernel sizes of protonet convs.
        include_last_relu (bool): Whether to keep the last ReLU of the
            protonet.
        num_protos (int): Number of prototypes.
        num_classes (int): Number of categories excluding the background
            category.
        loss_mask_weight (float): Reweight the mask loss by this factor.
        max_masks_to_train (int): Maximum number of masks to train for
            each image.
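
    Example:
        How each ``(channels, kernel_size)`` pair is decoded (a sketch of
        the convention implemented in ``_init_layers``; shapes are
        illustrative)::

            >>> # (256, 3)   -> 3x3 conv,  (256, -2) -> 2x2 deconv,
            >>> # (None, -2) -> 2x bilinear upsampling
            >>> self = YOLACTProtonet(num_classes=80, in_channels=256)
            >>> feat = torch.rand(1, 256, 69, 69)
            >>> prototypes = self.protonet(feat)
            >>> # the (None, -2) entry doubles the spatial resolution
            >>> assert prototypes.shape == (1, 32, 138, 138)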
| """ |

    def __init__(self,
                 num_classes,
                 in_channels=256,
                 proto_channels=(256, 256, 256, None, 256, 32),
                 proto_kernel_sizes=(3, 3, 3, -2, 3, 1),
                 include_last_relu=True,
                 num_protos=32,
                 loss_mask_weight=1.0,
                 max_masks_to_train=100):
        super(YOLACTProtonet, self).__init__()
        self.in_channels = in_channels
        self.proto_channels = proto_channels
        self.proto_kernel_sizes = proto_kernel_sizes
        self.include_last_relu = include_last_relu
        self.protonet = self._init_layers()

        self.loss_mask_weight = loss_mask_weight
        self.num_protos = num_protos
        self.num_classes = num_classes
        self.max_masks_to_train = max_masks_to_train
        self.fp16_enabled = False

    def _init_layers(self):
        """A helper function to take a config setting and turn it into a
        network."""
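        # Possible patterns:
        #  ( 256, 3) -> conv
        #  ( 256,-2) -> deconv
        #  (None,-2) -> bilinear interpolate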
        in_channels = self.in_channels
        protonets = nn.ModuleList()
        for num_channels, kernel_size in zip(self.proto_channels,
                                             self.proto_kernel_sizes):
            if kernel_size > 0:
                layer = nn.Conv2d(
                    in_channels,
                    num_channels,
                    kernel_size,
                    padding=kernel_size // 2)
            else:
                if num_channels is None:
                    layer = InterpolateModule(
                        scale_factor=-kernel_size,
                        mode='bilinear',
                        align_corners=False)
                else:
                    # kernel_size is negative in this branch, so flip its
                    # sign for both the deconv kernel and the padding
                    # (``padding=kernel_size // 2`` would be negative and
                    # rejected by ConvTranspose2d)
                    layer = nn.ConvTranspose2d(
                        in_channels,
                        num_channels,
                        -kernel_size,
                        padding=-kernel_size // 2)
            protonets.append(layer)
            protonets.append(nn.ReLU(inplace=True))
            in_channels = num_channels if num_channels is not None \
                else in_channels
        if not self.include_last_relu:
            protonets = protonets[:-1]
        return nn.Sequential(*protonets)

    def init_weights(self):
        """Initialize weights of the head."""
        for m in self.protonet:
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform')

    def forward(self, x, coeff_pred, bboxes, img_meta, sampling_results=None):
        """Forward feature from the upstream network to get prototypes and
        linearly combine the prototypes, using the mask coefficients, into
        instance masks. Finally, crop the instance masks with given bboxes.

        Args:
            x (Tensor): Feature from the upstream network, which is
                a 4D-tensor.
            coeff_pred (list[Tensor]): Mask coefficients for each scale
                level with shape (N, num_anchors * num_protos, H, W).
            bboxes (list[Tensor]): Boxes used for cropping, one tensor per
                image. During training, they are ground truth boxes.
                During testing, they are predicted boxes.
            img_meta (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            sampling_results (List[:obj:``SamplingResult``]): Sampler results
                for each image.

        Returns:
            list[Tensor]: Predicted instance segmentation masks.
        """
        prototypes = self.protonet(x)
        prototypes = prototypes.permute(0, 2, 3, 1).contiguous()

        num_imgs = x.size(0)
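        # at train time the coefficients still come as per-level maps;
        # flatten and concat them to (num_imgs, num_total_anchors,
        # num_protos) so positive samples can be indexed directly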
        if self.training:
            coeff_pred_list = []
            for coeff_pred_per_level in coeff_pred:
                coeff_pred_per_level = \
                    coeff_pred_per_level.permute(0, 2, 3, 1)\
                    .reshape(num_imgs, -1, self.num_protos)
                coeff_pred_list.append(coeff_pred_per_level)
            coeff_pred = torch.cat(coeff_pred_list, dim=1)

        mask_pred_list = []
        for idx in range(num_imgs):
            cur_prototypes = prototypes[idx]
            cur_coeff_pred = coeff_pred[idx]
            cur_bboxes = bboxes[idx]
            cur_img_meta = img_meta[idx]

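            # at test time every predicted box keeps its coefficients,
            # while at train time only the positive samples are assembled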
            if not self.training:
                bboxes_for_cropping = cur_bboxes
            else:
                cur_sampling_results = sampling_results[idx]
                pos_assigned_gt_inds = \
                    cur_sampling_results.pos_assigned_gt_inds
                bboxes_for_cropping = cur_bboxes[pos_assigned_gt_inds].clone()
                pos_inds = cur_sampling_results.pos_inds
                cur_coeff_pred = cur_coeff_pred[pos_inds]

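            # linearly combine prototypes and per-instance coefficients:
            # (H, W, num_protos) @ (num_protos, num_pos) -> (H, W, num_pos)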
            mask_pred = cur_prototypes @ cur_coeff_pred.t()
            mask_pred = torch.sigmoid(mask_pred)

            h, w = cur_img_meta['img_shape'][:2]
            bboxes_for_cropping[:, 0] /= w
            bboxes_for_cropping[:, 1] /= h
            bboxes_for_cropping[:, 2] /= w
            bboxes_for_cropping[:, 3] /= h

            mask_pred = self.crop(mask_pred, bboxes_for_cropping)
            mask_pred = mask_pred.permute(2, 0, 1).contiguous()
            mask_pred_list.append(mask_pred)
        return mask_pred_list

    @force_fp32(apply_to=('mask_pred', ))
    def loss(self, mask_pred, gt_masks, gt_bboxes, img_meta, sampling_results):
        """Compute loss of the head.

        Args:
            mask_pred (list[Tensor]): Predicted instance masks for each
                image, each with shape (num_pos, H, W).
            gt_masks (list[Tensor]): Ground truth masks for each image with
                the same shape as the input image.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            img_meta (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            sampling_results (List[:obj:``SamplingResult``]): Sampler results
                for each image.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        loss_mask = []
        num_imgs = len(mask_pred)
        total_pos = 0
        for idx in range(num_imgs):
            cur_mask_pred = mask_pred[idx]
            cur_gt_masks = gt_masks[idx].float()
            cur_gt_bboxes = gt_bboxes[idx]
            cur_img_meta = img_meta[idx]
            cur_sampling_results = sampling_results[idx]

            pos_assigned_gt_inds = cur_sampling_results.pos_assigned_gt_inds
            num_pos = pos_assigned_gt_inds.size(0)
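            # Since we are producing (near) full image masks, it would take
            # too much vram to backprop on every single mask; train on a
            # random subset instead.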
            if num_pos > self.max_masks_to_train:
                perm = torch.randperm(num_pos)
                select = perm[:self.max_masks_to_train]
                cur_mask_pred = cur_mask_pred[select]
                pos_assigned_gt_inds = pos_assigned_gt_inds[select]
                num_pos = self.max_masks_to_train
            total_pos += num_pos

            gt_bboxes_for_reweight = cur_gt_bboxes[pos_assigned_gt_inds]

            mask_targets = self.get_targets(cur_mask_pred, cur_gt_masks,
                                            pos_assigned_gt_inds)
            if num_pos == 0:
                loss = cur_mask_pred.sum() * 0.
            elif mask_targets is None:
                loss = F.binary_cross_entropy(cur_mask_pred,
                                              torch.zeros_like(cur_mask_pred),
                                              torch.zeros_like(cur_mask_pred))
            else:
                cur_mask_pred = torch.clamp(cur_mask_pred, 0, 1)
                loss = F.binary_cross_entropy(
                    cur_mask_pred, mask_targets,
                    reduction='none') * self.loss_mask_weight

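                # normalize each mask loss by its gt box size so that large
                # objects do not dominate the loss (cf. the YOLACT paper)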
                h, w = cur_img_meta['img_shape'][:2]
                gt_bboxes_width = (gt_bboxes_for_reweight[:, 2] -
                                   gt_bboxes_for_reweight[:, 0]) / w
                gt_bboxes_height = (gt_bboxes_for_reweight[:, 3] -
                                    gt_bboxes_for_reweight[:, 1]) / h
                loss = loss.mean(dim=(1,
                                      2)) / gt_bboxes_width / gt_bboxes_height
                loss = torch.sum(loss)
            loss_mask.append(loss)

        if total_pos == 0:
            total_pos += 1
        loss_mask = [x / total_pos for x in loss_mask]

        return dict(loss_mask=loss_mask)

    def get_targets(self, mask_pred, gt_masks, pos_assigned_gt_inds):
        """Compute instance segmentation targets for each image.

        Args:
            mask_pred (Tensor): Predicted instance masks with shape
                (num_pos, H, W).
            gt_masks (Tensor): Ground truth masks for each image with
                the same shape as the input image.
            pos_assigned_gt_inds (Tensor): GT indices of the corresponding
                positive samples.

        Returns:
            Tensor: Instance segmentation targets with shape
                (num_pos, H, W).
        """
        if gt_masks.size(0) == 0:
            return None
        mask_h, mask_w = mask_pred.shape[-2:]
        gt_masks = F.interpolate(
            gt_masks.unsqueeze(0), (mask_h, mask_w),
            mode='bilinear',
            align_corners=False).squeeze(0)
        gt_masks = gt_masks.gt(0.5).float()
        mask_targets = gt_masks[pos_assigned_gt_inds]
        return mask_targets

    def get_seg_masks(self, mask_pred, label_pred, img_meta, rescale):
        """Resize, binarize, and format the instance mask predictions.

        Args:
            mask_pred (Tensor): shape (N, H, W).
            label_pred (Tensor): shape (N, ).
            img_meta (dict): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            rescale (bool): If rescale is False, then returned masks will
                fit the scale of imgs[0].

        Returns:
            list[ndarray]: Mask predictions grouped by their predicted
                classes.
        """
        ori_shape = img_meta['ori_shape']
        scale_factor = img_meta['scale_factor']
        if rescale:
            img_h, img_w = ori_shape[:2]
        else:
            img_h = np.round(ori_shape[0] * scale_factor[1]).astype(np.int32)
            img_w = np.round(ori_shape[1] * scale_factor[0]).astype(np.int32)

        cls_segms = [[] for _ in range(self.num_classes)]
        if mask_pred.size(0) == 0:
            return cls_segms

        mask_pred = F.interpolate(
            mask_pred.unsqueeze(0), (img_h, img_w),
            mode='bilinear',
            align_corners=False).squeeze(0) > 0.5
        mask_pred = mask_pred.cpu().numpy().astype(np.uint8)

        for m, l in zip(mask_pred, label_pred):
            cls_segms[l].append(m)
        return cls_segms

    def crop(self, masks, boxes, padding=1):
        """Crop predicted masks by zeroing out everything not in the predicted
        bbox.

        Args:
            masks (Tensor): shape [H, W, N].
            boxes (Tensor): bbox coords in relative point form with
                shape [N, 4].

        Return:
            Tensor: The cropped masks.
        """
        h, w, n = masks.size()
        x1, x2 = self.sanitize_coordinates(
            boxes[:, 0], boxes[:, 2], w, padding, cast=False)
        y1, y2 = self.sanitize_coordinates(
            boxes[:, 1], boxes[:, 3], h, padding, cast=False)

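        # build per-pixel coordinate grids and keep only the pixels that
        # fall inside each (padded) box; everything outside is zeroed out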
        rows = torch.arange(
            w, device=masks.device, dtype=x1.dtype).view(1, -1,
                                                         1).expand(h, w, n)
        cols = torch.arange(
            h, device=masks.device, dtype=x1.dtype).view(-1, 1,
                                                         1).expand(h, w, n)

        masks_left = rows >= x1.view(1, 1, -1)
        masks_right = rows < x2.view(1, 1, -1)
        masks_up = cols >= y1.view(1, 1, -1)
        masks_down = cols < y2.view(1, 1, -1)

        crop_mask = masks_left * masks_right * masks_up * masks_down

        return masks * crop_mask.float()

    def sanitize_coordinates(self, x1, x2, img_size, padding=0, cast=True):
        """Sanitizes the input coordinates so that x1 < x2, x1 != x2, x1 >= 0,
        and x2 <= image_size. Also converts from relative to absolute
        coordinates and casts the results to long tensors.

        Args:
            x1 (Tensor): shape (N, ).
            x2 (Tensor): shape (N, ).
            img_size (int): Size of the input image.
            padding (int): x1 >= padding, x2 <= image_size-padding.
            cast (bool): If cast is false, the result won't be cast to longs.

        Returns:
            tuple:
                x1 (Tensor): Sanitized x1.
                x2 (Tensor): Sanitized x2.
        """
        x1 = x1 * img_size
        x2 = x2 * img_size
        if cast:
            x1 = x1.long()
            x2 = x2.long()
        # order the pair with a simultaneous assignment; computing min and
        # max sequentially would let the reassigned x1 clobber the operand
        # still needed for the max
        x1, x2 = torch.min(x1, x2), torch.max(x1, x2)
        x1 = torch.clamp(x1 - padding, min=0)
        x2 = torch.clamp(x2 + padding, max=img_size)
        return x1, x2


class InterpolateModule(nn.Module):
    """This is a module version of F.interpolate.

    Any arguments you give it just get passed along for the ride.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

        self.args = args
        self.kwargs = kwargs

    def forward(self, x):
        """Forward features from the upstream network."""
        return F.interpolate(x, *self.args, **self.kwargs)