|
|
|
|
|
|
|
|
import unittest |
|
|
import torch |
|
|
|
|
|
from detectron2.structures import Boxes, BoxMode, Instances |
|
|
|
|
|
from densepose.modeling.losses.utils import ChartBasedAnnotationsAccumulator |
|
|
from densepose.structures import DensePoseDataRelative, DensePoseList |
|
|
|
|
|
image_shape = (100, 100) |
|
|
instances = Instances(image_shape) |
|
|
n_instances = 3 |
|
|
instances.proposal_boxes = Boxes(torch.rand(n_instances, 4)) |
|
|
instances.gt_boxes = Boxes(torch.rand(n_instances, 4)) |
|
|
|
|
|
|
|
|
|
|
|
class TestChartBasedAnnotationsAccumulator(unittest.TestCase): |
|
|
def test_chart_based_annotations_accumulator_no_gt_densepose(self): |
|
|
accumulator = ChartBasedAnnotationsAccumulator() |
|
|
accumulator.accumulate(instances) |
|
|
expected_values = {"nxt_bbox_with_dp_index": 0, "nxt_bbox_index": n_instances} |
|
|
for key in accumulator.__dict__: |
|
|
self.assertEqual(getattr(accumulator, key), expected_values.get(key, [])) |
|
|
|
|
|
def test_chart_based_annotations_accumulator_gt_densepose_none(self): |
|
|
instances.gt_densepose = [None] * n_instances |
|
|
accumulator = ChartBasedAnnotationsAccumulator() |
|
|
accumulator.accumulate(instances) |
|
|
expected_values = {"nxt_bbox_with_dp_index": 0, "nxt_bbox_index": n_instances} |
|
|
for key in accumulator.__dict__: |
|
|
self.assertEqual(getattr(accumulator, key), expected_values.get(key, [])) |
|
|
|
|
|
def test_chart_based_annotations_accumulator_gt_densepose(self): |
|
|
data_relative_keys = [ |
|
|
DensePoseDataRelative.X_KEY, |
|
|
DensePoseDataRelative.Y_KEY, |
|
|
DensePoseDataRelative.I_KEY, |
|
|
DensePoseDataRelative.U_KEY, |
|
|
DensePoseDataRelative.V_KEY, |
|
|
DensePoseDataRelative.S_KEY, |
|
|
] |
|
|
annotations = [DensePoseDataRelative({k: [0] for k in data_relative_keys})] * n_instances |
|
|
instances.gt_densepose = DensePoseList(annotations, instances.gt_boxes, image_shape) |
|
|
accumulator = ChartBasedAnnotationsAccumulator() |
|
|
accumulator.accumulate(instances) |
|
|
bbox_xywh_est = BoxMode.convert( |
|
|
instances.proposal_boxes.tensor.clone(), BoxMode.XYXY_ABS, BoxMode.XYWH_ABS |
|
|
) |
|
|
bbox_xywh_gt = BoxMode.convert( |
|
|
instances.gt_boxes.tensor.clone(), BoxMode.XYXY_ABS, BoxMode.XYWH_ABS |
|
|
) |
|
|
expected_values = { |
|
|
"s_gt": [ |
|
|
torch.zeros((3, DensePoseDataRelative.MASK_SIZE, DensePoseDataRelative.MASK_SIZE)) |
|
|
] |
|
|
* n_instances, |
|
|
"bbox_xywh_est": bbox_xywh_est.split(1), |
|
|
"bbox_xywh_gt": bbox_xywh_gt.split(1), |
|
|
"point_bbox_with_dp_indices": [torch.tensor([i]) for i in range(n_instances)], |
|
|
"point_bbox_indices": [torch.tensor([i]) for i in range(n_instances)], |
|
|
"bbox_indices": list(range(n_instances)), |
|
|
"nxt_bbox_with_dp_index": n_instances, |
|
|
"nxt_bbox_index": n_instances, |
|
|
} |
|
|
default_value = [torch.tensor([0])] * 3 |
|
|
for key in accumulator.__dict__: |
|
|
to_test = getattr(accumulator, key) |
|
|
gt_value = expected_values.get(key, default_value) |
|
|
if key in ["nxt_bbox_with_dp_index", "nxt_bbox_index"]: |
|
|
self.assertEqual(to_test, gt_value) |
|
|
elif key == "bbox_indices": |
|
|
self.assertListEqual(to_test, gt_value) |
|
|
else: |
|
|
self.assertTrue(torch.allclose(torch.stack(to_test), torch.stack(gt_value))) |
|
|
|