| |
|
|
| from dataclasses import dataclass |
| from typing import Any, Optional |
| import torch |
|
|
| from detectron2.structures import BoxMode, Instances |
|
|
| from .utils import AnnotationsAccumulator |
|
|
|
|
@dataclass(init=True, repr=True, eq=True)
class PackedCseAnnotations:
    """
    CSE ground-truth annotations for a batch, packed into single tensors.

    Instances of this class are built by ``CseAnnotationsAccumulator.pack``.
    It concatenates, along dim 0, the per-instance data gathered for every
    GT-matched box that carried non-empty DensePose annotations.
    """

    # normalized X / Y coordinates of the annotated points, all instances
    x_gt: torch.Tensor
    y_gt: torch.Tensor
    # coarse segmentation masks, one leading entry per instance; None when
    # not every accumulated instance provided a mask
    coarse_segm_gt: Optional[torch.Tensor]
    # mesh ID of each annotated point (the instance mesh ID, broadcast)
    vertex_mesh_ids_gt: torch.Tensor
    # mesh vertex IDs of the annotated points
    vertex_ids_gt: torch.Tensor
    # GT and estimated (proposal) boxes in XYWH_ABS, one row per instance
    bbox_xywh_gt: torch.Tensor
    bbox_xywh_est: torch.Tensor
    # per-point index of the owning box among boxes that have DensePose data
    point_bbox_with_dp_indices: torch.Tensor
    # per-point index of the owning box among all GT-matched boxes
    point_bbox_indices: torch.Tensor
    # index, among all GT-matched boxes, of every box with DensePose data
    bbox_indices: torch.Tensor
|
|
|
|
class CseAnnotationsAccumulator(AnnotationsAccumulator):
    """
    Accumulates annotations by batches that correspond to objects detected on
    individual images. Can pack them together into single tensors.

    Two running counters are kept across ``accumulate`` calls:
      - ``nxt_bbox_index``: index of the next box among *all* GT-matched boxes
        seen so far, including boxes without DensePose annotations;
      - ``nxt_bbox_with_dp_index``: index of the next box among only those
        boxes that carry non-empty DensePose annotations.
    """

    def __init__(self):
        # Per-instance lists; ``pack`` concatenates each of them along dim 0.
        self.x_gt = []
        self.y_gt = []
        # coarse segmentation masks, only for instances that provide one
        self.s_gt = []
        self.vertex_mesh_ids_gt = []
        self.vertex_ids_gt = []
        self.bbox_xywh_gt = []
        self.bbox_xywh_est = []
        self.point_bbox_with_dp_indices = []
        self.point_bbox_indices = []
        self.bbox_indices = []
        self.nxt_bbox_with_dp_index = 0
        self.nxt_bbox_index = 0

    def accumulate(self, instances_one_image: Instances) -> None:
        """
        Accumulate instances data for one image

        Args:
            instances_one_image (Instances): instances data to accumulate.
                ``proposal_boxes`` and ``gt_boxes`` are read and interpreted
                as XYXY_ABS and must have equal lengths. ``gt_densepose`` is
                optional; when it is absent or None, the matched boxes only
                advance the box index.
        """
        boxes_xywh_est = BoxMode.convert(
            instances_one_image.proposal_boxes.tensor.clone(), BoxMode.XYXY_ABS, BoxMode.XYWH_ABS
        )
        boxes_xywh_gt = BoxMode.convert(
            instances_one_image.gt_boxes.tensor.clone(), BoxMode.XYXY_ABS, BoxMode.XYWH_ABS
        )
        n_matches = len(boxes_xywh_gt)
        assert n_matches == len(
            boxes_xywh_est
        ), f"Got {len(boxes_xywh_est)} proposal boxes and {len(boxes_xywh_gt)} GT boxes"
        if not n_matches:
            # no detection - GT matches
            return
        if getattr(instances_one_image, "gt_densepose", None) is None:
            # no densepose GT for the detections, just increase the bbox index
            self.nxt_bbox_index += n_matches
            return
        for box_xywh_est, box_xywh_gt, dp_gt in zip(
            boxes_xywh_est, boxes_xywh_gt, instances_one_image.gt_densepose
        ):
            if (dp_gt is not None) and (len(dp_gt.x) > 0):
                self._do_accumulate(box_xywh_gt, box_xywh_est, dp_gt)
            # every matched box advances the global box index, annotated or not
            self.nxt_bbox_index += 1

    def _do_accumulate(
        self, box_xywh_gt: torch.Tensor, box_xywh_est: torch.Tensor, dp_gt: Any
    ) -> None:
        """
        Accumulate data for one instance, given that its DensePose data is not empty

        Args:
            box_xywh_gt (tensor): GT bounding box (XYWH_ABS)
            box_xywh_est (tensor): estimated bounding box (XYWH_ABS)
            dp_gt: GT densepose data with the following attributes:
             - x: normalized X coordinates
             - y: normalized Y coordinates
             - segm (optional): tensor of size [S, S] with coarse segmentation;
               a missing attribute or a ``None`` value means "no segmentation"
             - vertex_ids: mesh vertex IDs of the annotated points
             - mesh_id: ID of the mesh that ``vertex_ids`` refer to
        """
        self.x_gt.append(dp_gt.x)
        self.y_gt.append(dp_gt.y)
        # A present-but-None mask is treated like a missing one instead of
        # crashing on ``unsqueeze``; ``pack`` then drops segmentation for the batch.
        segm = getattr(dp_gt, "segm", None)
        if segm is not None:
            self.s_gt.append(segm.unsqueeze(0))
        self.vertex_ids_gt.append(dp_gt.vertex_ids)
        # broadcast the per-instance mesh ID to the shape of ``vertex_ids``
        self.vertex_mesh_ids_gt.append(torch.full_like(dp_gt.vertex_ids, dp_gt.mesh_id))
        self.bbox_xywh_gt.append(box_xywh_gt.view(-1, 4))
        self.bbox_xywh_est.append(box_xywh_est.view(-1, 4))
        self.point_bbox_with_dp_indices.append(
            torch.full_like(dp_gt.vertex_ids, self.nxt_bbox_with_dp_index)
        )
        self.point_bbox_indices.append(torch.full_like(dp_gt.vertex_ids, self.nxt_bbox_index))
        self.bbox_indices.append(self.nxt_bbox_index)
        self.nxt_bbox_with_dp_index += 1

    def pack(self) -> Optional[PackedCseAnnotations]:
        """
        Pack data into tensors

        Returns:
            ``PackedCseAnnotations`` with every accumulated list concatenated
            along dim 0, or None if no instance with DensePose data was
            accumulated. Its ``coarse_segm_gt`` is None unless every
            accumulated instance provided a segmentation mask.
        """
        if not self.x_gt:
            # Returning proper empty annotations would require creating empty
            # tensors of appropriate shape and type on an appropriate device;
            # None is returned so far to indicate empty annotations.
            return None
        # ignore segmentation annotations, if not all the instances contain those
        has_all_segm = len(self.s_gt) == len(self.bbox_xywh_gt)
        return PackedCseAnnotations(
            x_gt=torch.cat(self.x_gt, 0),
            y_gt=torch.cat(self.y_gt, 0),
            vertex_mesh_ids_gt=torch.cat(self.vertex_mesh_ids_gt, 0),
            vertex_ids_gt=torch.cat(self.vertex_ids_gt, 0),
            coarse_segm_gt=torch.cat(self.s_gt, 0) if has_all_segm else None,
            bbox_xywh_gt=torch.cat(self.bbox_xywh_gt, 0),
            bbox_xywh_est=torch.cat(self.bbox_xywh_est, 0),
            point_bbox_with_dp_indices=torch.cat(self.point_bbox_with_dp_indices, 0),
            point_bbox_indices=torch.cat(self.point_bbox_indices, 0),
            # plain Python ints -> long tensor on the device of the point data
            bbox_indices=torch.as_tensor(
                self.bbox_indices, dtype=torch.long, device=self.x_gt[0].device
            ),
        )
|
|