| |
| import numpy as np |
| from typing import Any, List |
|
|
| from detectron2.modeling import ROI_MASK_HEAD_REGISTRY |
| from detectron2.modeling.roi_heads.mask_head import MaskRCNNConvUpsampleHead, mask_rcnn_inference |
| from detectron2.projects.point_rend import ImplicitPointRendMaskHead |
| from detectron2.projects.point_rend.point_features import point_sample |
| from detectron2.projects.point_rend.point_head import roi_mask_point_loss |
| from detectron2.structures import Instances |
|
|
| from .point_utils import get_point_coords_from_point_annotation |
|
|
| __all__ = [ |
| "ImplicitPointRendPointSupHead", |
| "MaskRCNNConvUpsamplePointSupHead", |
| ] |
|
|
|
|
@ROI_MASK_HEAD_REGISTRY.register()
class MaskRCNNConvUpsamplePointSupHead(MaskRCNNConvUpsampleHead):
    """
    A mask head with several conv layers, plus an upsample layer (with `ConvTranspose2d`).
    Predictions are made with a final 1x1 conv layer.

    The difference from `MaskRCNNConvUpsampleHead` is that this head is trained
    with point supervision. Please use `MaskRCNNConvUpsampleHead` if you want
    to train the model with mask supervision.
    """

    def forward(self, x, instances: List[Instances]) -> Any:
        """
        Run the mask head on region features.

        Args:
            x: input region feature(s) provided by :class:`ROIHeads`.
            instances (list[Instances]): contains the boxes & labels corresponding
                to the input features.
                Exact format is up to its caller to decide.
                Typically, this is the foreground instances in training, with
                "proposal_boxes" field and other gt annotations.
                In inference, it contains boxes that are already predicted.

        Returns:
            A dict of losses in training. The predicted "instances" in inference.
        """
        x = self.layers(x)

        if not self.training:
            # Inference: `mask_rcnn_inference` is called for its effect on
            # `instances`, which are then returned.
            mask_rcnn_inference(x, instances)
            return instances

        # Unpacking also enforces that the mask logits are 4-dimensional.
        N, _, H, W = x.shape
        assert H == W

        proposal_boxes = [inst.proposal_boxes for inst in instances]
        # Builtin `sum` on purpose: passing a generator to `np.sum` is
        # deprecated by NumPy and only works through a fallback.
        assert N == sum(len(boxes) for boxes in proposal_boxes)

        if N == 0:
            # No proposals: return a zero loss that is still computed from `x`.
            return {"loss_mask": x.sum() * 0}

        # Training loss is evaluated only at the annotated points.
        point_coords, point_labels = get_point_coords_from_point_annotation(instances)

        # Sample the predicted mask logits at the annotated point locations.
        mask_logits = point_sample(
            x,
            point_coords,
            align_corners=False,
        )

        return {"loss_mask": roi_mask_point_loss(mask_logits, instances, point_labels)}
|
|
|
|
@ROI_MASK_HEAD_REGISTRY.register()
class ImplicitPointRendPointSupHead(ImplicitPointRendMaskHead):
    """
    An `ImplicitPointRendMaskHead` whose training points are taken from the
    point annotations attached to the instances.
    """

    def _uniform_sample_train_points(self, instances):
        """
        Return the training point coordinates and their labels.

        Both values come from `get_point_coords_from_point_annotation`.
        Only valid in training mode.

        Args:
            instances: passed unchanged to `get_point_coords_from_point_annotation`.

        Returns:
            A `(point_coords, point_labels)` tuple.
        """
        assert self.training

        coords, labels = get_point_coords_from_point_annotation(instances)
        return coords, labels
|
|