| |
| from typing import List |
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
|
|
| from annotator.oneformer.detectron2.config import configurable |
| from annotator.oneformer.detectron2.layers import Conv2d, ConvTranspose2d, cat, interpolate |
| from annotator.oneformer.detectron2.structures import Instances, heatmaps_to_keypoints |
| from annotator.oneformer.detectron2.utils.events import get_event_storage |
| from annotator.oneformer.detectron2.utils.registry import Registry |
|
|
# Running count of minibatches skipped by keypoint_rcnn_loss because no
# visible ground-truth keypoints were available; logged to the event storage.
_TOTAL_SKIPPED = 0


__all__ = [
    "ROI_KEYPOINT_HEAD_REGISTRY",
    "build_keypoint_head",
    "BaseKeypointRCNNHead",
    "KRCNNConvDeconvUpsampleHead",
]


ROI_KEYPOINT_HEAD_REGISTRY = Registry("ROI_KEYPOINT_HEAD")
ROI_KEYPOINT_HEAD_REGISTRY.__doc__ = """
Registry for keypoint heads, which make keypoint predictions from per-region features.

The registered object will be called with `obj(cfg, input_shape)`.
"""
|
|
|
|
def build_keypoint_head(cfg, input_shape):
    """
    Build a keypoint head from `cfg.MODEL.ROI_KEYPOINT_HEAD.NAME`.
    """
    head_name = cfg.MODEL.ROI_KEYPOINT_HEAD.NAME
    head_cls = ROI_KEYPOINT_HEAD_REGISTRY.get(head_name)
    return head_cls(cfg, input_shape)
|
|
|
|
def keypoint_rcnn_loss(pred_keypoint_logits, instances, normalizer):
    """
    Compute the keypoint heatmap cross-entropy loss (Sec. 5 of Mask R-CNN).

    Arguments:
        pred_keypoint_logits (Tensor): A tensor of shape (N, K, S, S) where N is the total number
            of instances in the batch, K is the number of keypoints, and S is the side length
            of the keypoint heatmap. The values are spatial logits.
        instances (list[Instances]): A list of M Instances, where M is the batch size.
            These instances are predictions from the model
            that are in 1:1 correspondence with pred_keypoint_logits.
            Each Instances should contain a `gt_keypoints` field containing a `structures.Keypoint`
            instance.
        normalizer (float): Normalize the loss by this amount.
            If not specified, we normalize by the number of visible keypoints in the minibatch.

    Returns a scalar tensor containing the loss.
    """
    heatmaps = []
    valid = []

    keypoint_side_len = pred_keypoint_logits.shape[2]
    for instances_per_image in instances:
        # Images with no foreground instances contribute nothing.
        if len(instances_per_image) == 0:
            continue
        keypoints = instances_per_image.gt_keypoints
        # Rasterize GT keypoints into per-box SxS heatmap target indices plus a
        # per-keypoint visibility mask.
        heatmaps_per_image, valid_per_image = keypoints.to_heatmap(
            instances_per_image.proposal_boxes.tensor, keypoint_side_len
        )
        heatmaps.append(heatmaps_per_image.view(-1))
        valid.append(valid_per_image.view(-1))

    if len(heatmaps):
        keypoint_targets = cat(heatmaps, dim=0)
        valid = cat(valid, dim=0).to(dtype=torch.uint8)
        # `valid` is reused: first a 0/1 mask, now the indices of visible keypoints.
        valid = torch.nonzero(valid).squeeze(1)

    # torch.mean (in binary_cross_entropy_with_logits) doesn't
    # accept empty tensors, so handle it separately
    if len(heatmaps) == 0 or valid.numel() == 0:
        global _TOTAL_SKIPPED
        _TOTAL_SKIPPED += 1
        storage = get_event_storage()
        storage.put_scalar("kpts_num_skipped_batches", _TOTAL_SKIPPED, smoothing_hint=False)
        # Multiply by 0 (rather than returning a constant) to keep the output
        # connected to the autograd graph of the predictions.
        return pred_keypoint_logits.sum() * 0

    # Flatten each K x S x S heatmap into S*S-way classification rows so that
    # each visible keypoint is one cross-entropy example over spatial positions.
    N, K, H, W = pred_keypoint_logits.shape
    pred_keypoint_logits = pred_keypoint_logits.view(N * K, H * W)

    keypoint_loss = F.cross_entropy(
        pred_keypoint_logits[valid], keypoint_targets[valid], reduction="sum"
    )

    # If no normalizer is given, normalize by the number of visible keypoints.
    if normalizer is None:
        normalizer = valid.numel()
    keypoint_loss /= normalizer

    return keypoint_loss
|
|
|
|
def keypoint_rcnn_inference(pred_keypoint_logits: torch.Tensor, pred_instances: List[Instances]):
    """
    Post process each predicted keypoint heatmap in `pred_keypoint_logits` into (x, y, score)
    and add it to the `pred_instances` as a `pred_keypoints` field.

    Args:
        pred_keypoint_logits (Tensor): A tensor of shape (R, K, S, S) where R is the total number
            of instances in the batch, K is the number of keypoints, and S is the side length of
            the keypoint heatmap. The values are spatial logits.
        pred_instances (list[Instances]): A list of N Instances, where N is the number of images.

    Returns:
        None. Each element in pred_instances will contain extra "pred_keypoints" and
        "pred_keypoint_heatmaps" fields. "pred_keypoints" is a tensor of shape
        (#instance, K, 3) where the last dimension corresponds to (x, y, score).
        The scores are larger than 0. "pred_keypoint_heatmaps" contains the raw
        keypoint logits as passed to this function.
    """
    # Flatten predicted boxes across all images so they align 1:1 with the
    # R heatmaps; gradients are not needed at inference time.
    boxes_per_image = [instances.pred_boxes.tensor for instances in pred_instances]
    all_boxes = cat(boxes_per_image, dim=0)

    logits = pred_keypoint_logits.detach()
    # (R, K, 4) results where the last dim is (x, y, logit, prob-score).
    all_keypoints = heatmaps_to_keypoints(logits, all_boxes.detach())

    counts = [len(instances) for instances in pred_instances]
    # Keep (x, y, score) — drop column 2 (the raw logit) — then regroup per image.
    keypoints_split = all_keypoints[:, :, [0, 1, 3]].split(counts, dim=0)
    heatmaps_split = logits.split(counts, dim=0)

    for instances_per_image, kpts, hms in zip(pred_instances, keypoints_split, heatmaps_split):
        instances_per_image.pred_keypoints = kpts
        instances_per_image.pred_keypoint_heatmaps = hms
|
|
|
|
class BaseKeypointRCNNHead(nn.Module):
    """
    Implement the basic Keypoint R-CNN losses and inference logic described in
    Sec. 5 of :paper:`Mask R-CNN`.
    """

    @configurable
    def __init__(self, *, num_keypoints, loss_weight=1.0, loss_normalizer=1.0):
        """
        NOTE: this interface is experimental.

        Args:
            num_keypoints (int): number of keypoints to predict
            loss_weight (float): weight to multiple on the keypoint loss
            loss_normalizer (float or str):
                If float, divide the loss by `loss_normalizer * #images`.
                If 'visible', the loss is normalized by the total number of
                visible keypoints across images.
        """
        super().__init__()
        assert loss_normalizer == "visible" or isinstance(loss_normalizer, float), loss_normalizer
        self.num_keypoints = num_keypoints
        self.loss_weight = loss_weight
        self.loss_normalizer = loss_normalizer

    @classmethod
    def from_config(cls, cfg, input_shape):
        head_cfg = cfg.MODEL.ROI_KEYPOINT_HEAD
        ret = {
            "loss_weight": head_cfg.LOSS_WEIGHT,
            "num_keypoints": head_cfg.NUM_KEYPOINTS,
        }
        if head_cfg.NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS:
            ret["loss_normalizer"] = "visible"
        else:
            # Expected number of foreground keypoints per image:
            # K * (proposals per image) * (positive fraction).
            ret["loss_normalizer"] = (
                ret["num_keypoints"]
                * cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE
                * cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION
            )
        return ret

    def forward(self, x, instances: List[Instances]):
        """
        Args:
            x: input 4D region feature(s) provided by :class:`ROIHeads`.
            instances (list[Instances]): contains the boxes & labels corresponding
                to the input features.
                Exact format is up to its caller to decide.
                Typically, this is the foreground instances in training, with
                "proposal_boxes" field and other gt annotations.
                In inference, it contains boxes that are already predicted.

        Returns:
            A dict of losses if in training. The predicted "instances" if in inference.
        """
        x = self.layers(x)
        if not self.training:
            keypoint_rcnn_inference(x, instances)
            return instances
        if self.loss_normalizer == "visible":
            # None lets keypoint_rcnn_loss normalize by visible keypoint count.
            normalizer = None
        else:
            normalizer = len(instances) * self.loss_normalizer
        loss = keypoint_rcnn_loss(x, instances, normalizer=normalizer) * self.loss_weight
        return {"loss_keypoint": loss}

    def layers(self, x):
        """
        Neural network layers that makes predictions from regional input features.
        """
        raise NotImplementedError
|
|
|
|
| |
| |
| |
@ROI_KEYPOINT_HEAD_REGISTRY.register()
class KRCNNConvDeconvUpsampleHead(BaseKeypointRCNNHead, nn.Sequential):
    """
    A standard keypoint head containing a series of 3x3 convs, followed by
    a transpose convolution and bilinear interpolation for upsampling.
    It is described in Sec. 5 of :paper:`Mask R-CNN`.
    """

    @configurable
    def __init__(self, input_shape, *, num_keypoints, conv_dims, **kwargs):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape (ShapeSpec): shape of the input feature
            conv_dims: an iterable of output channel counts for each conv in the head
                e.g. (512, 512, 512) for three convs outputting 512 channels.
        """
        super().__init__(num_keypoints=num_keypoints, **kwargs)

        # Stack of 3x3 conv + ReLU blocks; each conv preserves spatial size.
        prev_channels = input_shape.channels
        for i, out_channels in enumerate(conv_dims, 1):
            self.add_module(
                "conv_fcn{}".format(i),
                Conv2d(prev_channels, out_channels, 3, stride=1, padding=1),
            )
            self.add_module("conv_fcn_relu{}".format(i), nn.ReLU())
            prev_channels = out_channels

        # Transposed conv doubles the resolution; bilinear interpolation in
        # `layers` doubles it again.
        deconv_kernel = 4
        self.score_lowres = ConvTranspose2d(
            prev_channels, num_keypoints, deconv_kernel, stride=2, padding=deconv_kernel // 2 - 1
        )
        self.up_scale = 2.0

        # Zero biases; Kaiming (fan_out, relu) init for all weights.
        for name, param in self.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, 0)
            elif "weight" in name:
                nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")

    @classmethod
    def from_config(cls, cfg, input_shape):
        ret = super().from_config(cfg, input_shape)
        ret["input_shape"] = input_shape
        ret["conv_dims"] = cfg.MODEL.ROI_KEYPOINT_HEAD.CONV_DIMS
        return ret

    def layers(self, x):
        # nn.Sequential iteration runs convs, ReLUs, and the deconv in order.
        for module in self:
            x = module(x)
        return interpolate(x, scale_factor=self.up_scale, mode="bilinear", align_corners=False)
|
|