| | |
| |
|
| | from typing import List |
| | import torch |
| |
|
| | from detectron2.config import CfgNode |
| | from detectron2.structures import Instances |
| | from detectron2.structures.boxes import matched_pairwise_iou |
| |
|
| |
|
class DensePoseDataFilter:
    """
    Filters proposals with targets to keep only those usable for DensePose
    training: a proposal is kept if it overlaps its matched GT box above an
    IoU threshold and carries a DensePose annotation (or, when training the
    coarse segmentation from masks, a mask annotation).
    """

    def __init__(self, cfg: CfgNode):
        """
        Args:
            cfg (CfgNode): configuration; reads
                MODEL.ROI_DENSEPOSE_HEAD.FG_IOU_THRESHOLD — minimum IoU
                between a proposal box and its GT box for the proposal
                to be kept, and
                MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS —
                if True, proposals that only have mask annotations
                (no DensePose) are also kept.
        """
        self.iou_threshold = cfg.MODEL.ROI_DENSEPOSE_HEAD.FG_IOU_THRESHOLD
        self.keep_masks = cfg.MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS

    @torch.no_grad()
    def __call__(self, features: List[torch.Tensor], proposals_with_targets: List[Instances]):
        """
        Filters proposals with targets to keep only the ones relevant for
        DensePose training

        Args:
            features (list[Tensor]): input data as a list of features,
                each feature is a tensor. Axis 0 represents the number of
                images `N` in the input data; axes 1-3 are channels,
                height, and width, which may vary between features
                (e.g., if a feature pyramid is used).
            proposals_with_targets (list[Instances]): length `N` list of
                `Instances`. The i-th `Instances` contains instances
                (proposals, GT) for the i-th input image,
        Returns:
            list[Tensor]: filtered features
            list[Instances]: filtered proposals
        """
        proposals_filtered = []
        for proposals_per_image in proposals_with_targets:
            # Skip images that carry neither DensePose annotations nor
            # (when COARSE_SEGM_TRAINED_BY_MASKS is on) mask annotations.
            if not proposals_per_image.has("gt_densepose") and (
                not proposals_per_image.has("gt_masks") or not self.keep_masks
            ):
                continue
            gt_boxes = proposals_per_image.gt_boxes
            est_boxes = proposals_per_image.proposal_boxes
            # Proposals and GT boxes are matched 1:1 by index; keep only
            # pairs whose IoU exceeds the foreground threshold.
            iou = matched_pairwise_iou(gt_boxes, est_boxes)
            iou_select = iou > self.iou_threshold
            proposals_per_image = proposals_per_image[iou_select]

            n_gt_boxes = len(proposals_per_image.gt_boxes)
            assert n_gt_boxes == len(proposals_per_image.proposal_boxes), (
                f"The number of GT boxes {n_gt_boxes} is different from the "
                f"number of proposal boxes {len(proposals_per_image.proposal_boxes)}"
            )
            # Collect per-proposal annotation targets; a missing field is
            # represented as a list of Nones so zip below stays aligned.
            # Use `Instances.has` consistently (the original mixed it with
            # `hasattr` for the same check).
            if self.keep_masks and proposals_per_image.has("gt_masks"):
                gt_masks = proposals_per_image.gt_masks
            else:
                gt_masks = [None] * n_gt_boxes
            gt_densepose = (
                proposals_per_image.gt_densepose
                if proposals_per_image.has("gt_densepose")
                else [None] * n_gt_boxes
            )
            assert len(gt_masks) == n_gt_boxes
            assert len(gt_densepose) == n_gt_boxes
            # Keep proposals that have at least one usable annotation
            # (DensePose or mask).
            selected_indices = [
                idx
                for idx, (dp_target, mask_target) in enumerate(zip(gt_densepose, gt_masks))
                if (dp_target is not None) or (mask_target is not None)
            ]
            # Avoid re-indexing the Instances when nothing is dropped.
            if len(selected_indices) != n_gt_boxes:
                proposals_per_image = proposals_per_image[selected_indices]
            assert len(proposals_per_image.gt_boxes) == len(proposals_per_image.proposal_boxes)
            proposals_filtered.append(proposals_per_image)
        return features, proposals_filtered
| |
|