| | |
| |
|
| | from typing import Any, Dict, List, Tuple |
| | import torch |
| | from torch.nn import functional as F |
| |
|
| | from detectron2.structures import BoxMode, Instances |
| |
|
| | from densepose.converters import ToChartResultConverter |
| | from densepose.converters.base import IntTupleBox, make_int_box |
| | from densepose.structures import DensePoseDataRelative, DensePoseList |
| |
|
| |
|
class DensePoseBaseSampler:
    """
    Base DensePose sampler to produce DensePose data from DensePose predictions.
    Samples for each class are drawn according to some distribution over all pixels estimated
    to belong to that class.

    The distribution itself is supplied by descendants through
    `_produce_index_sample`; this base class only implements the surrounding
    bookkeeping (label extraction, coordinate normalization, mask resampling).
    """

    # Fine segmentation label -> coarse segmentation label, used to build the
    # segmentation annotation in `_resample_mask`. Keys cover the fine labels
    # 1..24 and values the coarse labels 1..14; pixels whose fine label is 0
    # are never remapped and stay 0.
    # NOTE(review): the labels are presumably DensePose body-part indices -
    # confirm the part names against the DensePose annotation documentation.
    _FINE_TO_COARSE_SEGMENTATION = {
        1: 1,
        2: 1,
        3: 2,
        4: 3,
        5: 4,
        6: 5,
        7: 6,
        8: 7,
        9: 6,
        10: 7,
        11: 8,
        12: 9,
        13: 8,
        14: 9,
        15: 10,
        16: 11,
        17: 10,
        18: 11,
        19: 12,
        20: 13,
        21: 12,
        22: 13,
        23: 14,
        24: 14,
    }

    def __init__(self, count_per_class: int = 8):
        """
        Constructor

        Args:
            count_per_class (int): the sampler produces at most `count_per_class`
                samples for each category
        """
        self.count_per_class = count_per_class

    def __call__(self, instances: Instances) -> DensePoseList:
        """
        Convert DensePose predictions stored in detection instances into
        DensePose annotations data (an instance of `DensePoseList`)

        Args:
            instances (Instances): detections that provide `pred_boxes`,
                `pred_densepose` and `image_size`

        Return:
            DensePoseList: one `DensePoseDataRelative` entry per detection,
                paired with the detection boxes in absolute XYXY format (on CPU)
        """
        boxes_xyxy_abs = instances.pred_boxes.tensor.clone().cpu()
        boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
        dp_datas = []
        for i, box_xywh in enumerate(boxes_xywh_abs):
            # Select the i-th detection once and reuse it for both the point
            # samples and the segmentation mask.
            instance_i = instances[i]
            annotation_i = self._sample(instance_i, make_int_box(box_xywh))
            annotation_i[DensePoseDataRelative.S_KEY] = self._resample_mask(
                instance_i.pred_densepose
            )
            dp_datas.append(DensePoseDataRelative(annotation_i))
        return DensePoseList(dp_datas, boxes_xyxy_abs, instances.image_size)

    def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]:
        """
        Sample DensePoseDataRelative point annotations from estimation results

        Args:
            instance (Instances): a single detection with DensePose predictions
            bbox_xywh (IntTupleBox): bounding box of the detection in XYWH
                format; not used by this base implementation

        Return:
            dict: lists stored under the X, Y, U, V and I keys of
                `DensePoseDataRelative`; all lists have the same length, with
                at most `count_per_class` entries contributed by each part
        """
        labels, dp_result = self._produce_labels_and_results(instance)
        annotation = {
            DensePoseDataRelative.X_KEY: [],
            DensePoseDataRelative.Y_KEY: [],
            DensePoseDataRelative.U_KEY: [],
            DensePoseDataRelative.V_KEY: [],
            DensePoseDataRelative.I_KEY: [],
        }
        n, h, w = dp_result.shape
        # The labels are broadcast over the n result channels so that one index
        # tuple can address `dp_result`; the view does not depend on the part,
        # so it is created once outside the loop.
        labels_nhw = labels.expand(n, h, w)
        for part_id in range(1, DensePoseDataRelative.N_PART_LABELS + 1):
            # indices - tuple of 3 1D tensors (channel, row, column), each with
            # n * k entries, where k is the number of pixels labeled `part_id`
            indices = torch.nonzero(labels_nhw == part_id, as_tuple=True)
            # values - tensor of size [n, k] with the results at those pixels
            values = dp_result[indices].view(n, -1)
            k = values.shape[1]
            count = min(self.count_per_class, k)
            if count <= 0:
                continue
            index_sample = self._produce_index_sample(values, count)
            sampled_values = values[:, index_sample]
            # `index_sample` addresses columns of `values`, i.e. positions
            # 0..k-1; the same positions of the row / column index tensors hold
            # the matching pixel coordinates. 0.5 moves to the pixel center.
            sampled_y = indices[1][index_sample] + 0.5
            sampled_x = indices[2][index_sample] + 0.5
            # normalize the coordinates to the [0, 256] range, clamp U and V to [0, 1]
            x = (sampled_x / w * 256.0).cpu().tolist()
            y = (sampled_y / h * 256.0).cpu().tolist()
            u = sampled_values[0].clamp(0, 1).cpu().tolist()
            v = sampled_values[1].clamp(0, 1).cpu().tolist()
            # One label per point that was actually sampled, so that the I list
            # stays aligned with X / Y / U / V even if a descendant returns a
            # different number of indices than requested.
            fine_segm_labels = [part_id] * len(x)

            annotation[DensePoseDataRelative.X_KEY].extend(x)
            annotation[DensePoseDataRelative.Y_KEY].extend(y)
            annotation[DensePoseDataRelative.U_KEY].extend(u)
            annotation[DensePoseDataRelative.V_KEY].extend(v)
            annotation[DensePoseDataRelative.I_KEY].extend(fine_segm_labels)
        return annotation

    def _produce_index_sample(self, values: torch.Tensor, count: int):
        """
        Abstract method to produce a sample of indices to select data
        To be implemented in descendants

        Args:
            values (torch.Tensor): an array of size [n, k] that contains
                estimated values (U, V, confidences);
                n: number of channels (U, V, confidences)
                k: number of points labeled with part_id
            count (int): number of samples to produce, should be positive and <= k

        Return:
            list(int): indices of values (along axis 1) selected as a sample

        Raises:
            NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    def _produce_labels_and_results(self, instance: Instances) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Method to get labels and DensePose results from an instance

        Args:
            instance (Instances): a detection that provides `pred_densepose`
                and `pred_boxes`

        Return:
            labels (torch.Tensor): shape [H, W], DensePose segmentation labels
            dp_result (torch.Tensor): shape [2, H, W], stacked DensePose results u and v
        """
        chart_result = ToChartResultConverter.convert(instance.pred_densepose, instance.pred_boxes)
        # Sampling is performed on CPU
        return chart_result.labels.cpu(), chart_result.uv.cpu()

    def _resample_mask(self, output: Any) -> torch.Tensor:
        """
        Convert DensePose predictor output to segmentation annotation - a tensor of size
        (S, S) and type `int64`, where S = DensePoseDataRelative.MASK_SIZE.

        Args:
            output: DensePose predictor output with the following attributes:
             - coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse
               segmentation scores
             - fine_segm: tensor of size [N, C, H, W] with unnormalized fine
               segmentation scores
        Return:
            Tensor of size (S, S) and type `int64` with coarse segmentation annotations,
            where S = DensePoseDataRelative.MASK_SIZE
        """
        sz = DensePoseDataRelative.MASK_SIZE
        # Per-pixel argmax over the coarse scores, resampled to the mask size;
        # only used to tell which pixels have a non-zero coarse label.
        coarse_labels = (
            F.interpolate(output.coarse_segm, (sz, sz), mode="bilinear", align_corners=False)
            .argmax(dim=1)
            .long()
        )
        # Per-pixel argmax over the fine scores, zeroed wherever the coarse
        # label is 0.
        fine_labels = F.interpolate(
            output.fine_segm,
            (sz, sz),
            mode="bilinear",
            align_corners=False,
        ).argmax(dim=1)
        fine_labels = (fine_labels * (coarse_labels > 0).long()).squeeze().cpu()

        # Map fine segmentation results to coarse segmentation labels;
        # pixels with fine label 0 keep the initial value 0.
        fine_to_coarse = self._FINE_TO_COARSE_SEGMENTATION
        mask = torch.zeros((sz, sz), dtype=torch.int64, device=torch.device("cpu"))
        for part_id in range(1, DensePoseDataRelative.N_PART_LABELS + 1):
            mask[fine_labels == part_id] = fine_to_coarse[part_id]
        return mask
| |
|