# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

from typing import Any, Dict, List, Tuple
import torch
from torch.nn import functional as F

from ....detectron2.structures import BoxMode, Instances

from ...converters import ToChartResultConverter
from ...converters.base import IntTupleBox, make_int_box
from ...structures import DensePoseDataRelative, DensePoseList


# Map fine segmentation labels to coarse segmentation labels.
# TODO: extract this into separate classes
# coarse segmentation: 1 = Torso, 2 = Right Hand, 3 = Left Hand,
# 4 = Left Foot, 5 = Right Foot, 6 = Upper Leg Right, 7 = Upper Leg Left,
# 8 = Lower Leg Right, 9 = Lower Leg Left, 10 = Upper Arm Left,
# 11 = Upper Arm Right, 12 = Lower Arm Left, 13 = Lower Arm Right,
# 14 = Head
# fine segmentation: 1, 2 = Torso, 3 = Right Hand, 4 = Left Hand,
# 5 = Left Foot, 6 = Right Foot, 7, 9 = Upper Leg Right,
# 8, 10 = Upper Leg Left, 11, 13 = Lower Leg Right,
# 12, 14 = Lower Leg Left, 15, 17 = Upper Arm Left,
# 16, 18 = Upper Arm Right, 19, 21 = Lower Arm Left,
# 20, 22 = Lower Arm Right, 23, 24 = Head
# Built once at import time rather than on every `_resample_mask` call.
_FINE_TO_COARSE_SEGMENTATION: Dict[int, int] = {
    1: 1,  # Torso
    2: 1,  # Torso
    3: 2,  # Right Hand
    4: 3,  # Left Hand
    5: 4,  # Left Foot
    6: 5,  # Right Foot
    7: 6,  # Upper Leg Right
    8: 7,  # Upper Leg Left
    9: 6,  # Upper Leg Right
    10: 7,  # Upper Leg Left
    11: 8,  # Lower Leg Right
    12: 9,  # Lower Leg Left
    13: 8,  # Lower Leg Right
    14: 9,  # Lower Leg Left
    15: 10,  # Upper Arm Left
    16: 11,  # Upper Arm Right
    17: 10,  # Upper Arm Left
    18: 11,  # Upper Arm Right
    19: 12,  # Lower Arm Left
    20: 13,  # Lower Arm Right
    21: 12,  # Lower Arm Left
    22: 13,  # Lower Arm Right
    23: 14,  # Head
    24: 14,  # Head
}


class DensePoseBaseSampler:
    """
    Base DensePose sampler to produce DensePose data from DensePose predictions.
    Samples for each class are drawn according to some distribution over all pixels estimated
    to belong to that class. The distribution itself is defined by descendants through
    `_produce_index_sample`.
    """

    def __init__(self, count_per_class: int = 8):
        """
        Constructor

        Args:
          count_per_class (int): the sampler produces at most `count_per_class`
              samples for each category
        """
        self.count_per_class = count_per_class

    def __call__(self, instances: Instances) -> DensePoseList:
        """
        Convert DensePose predictions stored in `instances` (field `pred_densepose`,
        an instance of `DensePoseChartPredictorOutput`) into DensePose annotations
        data (an instance of `DensePoseList`).

        Args:
            instances (Instances): predicted instances with `pred_boxes` and
                `pred_densepose` fields
        Return:
            DensePoseList: one `DensePoseDataRelative` per instance, together with
                the instances' XYXY_ABS boxes (moved to CPU) and the image size
        """
        boxes_xyxy_abs = instances.pred_boxes.tensor.clone().cpu()
        # Sampling works with integer XYWH boxes; the output list keeps XYXY boxes.
        boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
        dp_datas = []
        for i, box_xywh_abs in enumerate(boxes_xywh_abs):
            instance_i = instances[i]
            annotation_i = self._sample(instance_i, make_int_box(box_xywh_abs))
            annotation_i[DensePoseDataRelative.S_KEY] = self._resample_mask(  # pyre-ignore[6]
                instance_i.pred_densepose
            )
            dp_datas.append(DensePoseDataRelative(annotation_i))
        # create densepose annotations on CPU
        return DensePoseList(dp_datas, boxes_xyxy_abs, instances.image_size)

    def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]:
        """
        Sample DensePoseDataRelative from estimation results

        Args:
            instance (Instances): a single predicted instance
            bbox_xywh (IntTupleBox): integer XYWH_ABS box of the instance;
                not used by this base implementation
        Return:
            Dict[str, List[Any]]: annotation dict keyed by the X, Y, U, V and I keys of
                `DensePoseDataRelative`; each list gets one entry per sampled point
        """
        labels, dp_result = self._produce_labels_and_results(instance)
        annotation = {
            DensePoseDataRelative.X_KEY: [],
            DensePoseDataRelative.Y_KEY: [],
            DensePoseDataRelative.U_KEY: [],
            DensePoseDataRelative.V_KEY: [],
            DensePoseDataRelative.I_KEY: [],
        }
        n, h, w = dp_result.shape
        for part_id in range(1, DensePoseDataRelative.N_PART_LABELS + 1):
            # indices - tuple of 3 1D tensors of size n * k
            # 0: index along the first dimension N
            # 1: index along H dimension
            # 2: index along W dimension
            # `labels` is broadcast over the n channels so the indices can be applied
            # to `dp_result` directly. `torch.nonzero` returns indices in row-major
            # order, so the first k entries of indices[1] / indices[2] are the (y, x)
            # positions of channel 0; `index_sample` (indices into [0, k)) can
            # therefore be applied to them below.
            indices = torch.nonzero(labels.expand(n, h, w) == part_id, as_tuple=True)
            # values - an array of size [n, k]
            # n: number of channels (U, V, confidences)
            # k: number of points labeled with part_id
            values = dp_result[indices].view(n, -1)
            k = values.shape[1]
            count = min(self.count_per_class, k)
            if count <= 0:
                # no pixels carry this part label
                continue
            index_sample = self._produce_index_sample(values, count)
            sampled_values = values[:, index_sample]
            # +0.5: use pixel centers rather than pixel corners
            sampled_y = indices[1][index_sample] + 0.5
            sampled_x = indices[2][index_sample] + 0.5
            # prepare / normalize data: rescale coordinates from the h x w label grid
            # to [0, 256]; clamp U, V into [0, 1]
            x = (sampled_x / w * 256.0).cpu().tolist()
            y = (sampled_y / h * 256.0).cpu().tolist()
            u = sampled_values[0].clamp(0, 1).cpu().tolist()
            v = sampled_values[1].clamp(0, 1).cpu().tolist()
            fine_segm_labels = [part_id] * count
            # extend annotations
            annotation[DensePoseDataRelative.X_KEY].extend(x)
            annotation[DensePoseDataRelative.Y_KEY].extend(y)
            annotation[DensePoseDataRelative.U_KEY].extend(u)
            annotation[DensePoseDataRelative.V_KEY].extend(v)
            annotation[DensePoseDataRelative.I_KEY].extend(fine_segm_labels)
        return annotation

    def _produce_index_sample(self, values: torch.Tensor, count: int):
        """
        Abstract method to produce a sample of indices to select data
        To be implemented in descendants

        Args:
            values (torch.Tensor): an array of size [n, k] that contains
                estimated values (U, V, confidences);
                n: number of channels (U, V, confidences)
                k: number of points labeled with part_id
            count (int): number of samples to produce, should be positive and <= k

        Return:
            list(int): indices of values (along axis 1) selected as a sample
        """
        raise NotImplementedError

    def _produce_labels_and_results(self, instance: Instances) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Method to get labels and DensePose results from an instance

        Args:
            instance (Instances): a single predicted instance whose `pred_densepose`
                field is an instance of `DensePoseChartPredictorOutput`
        Return:
            labels (torch.Tensor): shape [H, W], DensePose segmentation labels
            dp_result (torch.Tensor): shape [2, H, W], stacked DensePose results u and v
            Both tensors are moved to CPU.
        """
        converter = ToChartResultConverter
        chart_result = converter.convert(instance.pred_densepose, instance.pred_boxes)
        labels, dp_result = chart_result.labels.cpu(), chart_result.uv.cpu()
        return labels, dp_result

    def _resample_mask(self, output: Any) -> torch.Tensor:
        """
        Convert DensePose predictor output to segmentation annotation - a tensor of
        size (S, S) and type `int64`, where S = DensePoseDataRelative.MASK_SIZE.

        Args:
            output: DensePose predictor output with the following attributes:
             - coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse
                segmentation scores
             - fine_segm: tensor of size [N, C, H, W] with unnormalized fine
                segmentation scores
             N is expected to be 1 (a single instance, as passed by `__call__`).
        Return:
            Tensor of size (S, S) and type `int64` with coarse segmentation annotations,
            where S = DensePoseDataRelative.MASK_SIZE. Fine labels without an entry in
            the fine-to-coarse table stay 0.
        """
        sz = DensePoseDataRelative.MASK_SIZE
        coarse_labels = (
            F.interpolate(output.coarse_segm, (sz, sz), mode="bilinear", align_corners=False)
            .argmax(dim=1)
            .long()
        )
        # Fine labels are zeroed wherever the coarse segmentation predicts label 0,
        # so they survive only inside the coarse foreground.
        fine_labels = (
            (
                F.interpolate(
                    output.fine_segm,
                    (sz, sz),
                    mode="bilinear",
                    align_corners=False,
                ).argmax(dim=1)
                * (coarse_labels > 0).long()
            )
            .squeeze()
            .cpu()
        )
        # `squeeze()` removes size-1 dimensions; boolean indexing into the (sz, sz)
        # mask below requires `fine_labels` to be (sz, sz), i.e. N == 1.
        mask = torch.zeros((sz, sz), dtype=torch.int64, device=torch.device("cpu"))
        for fine_label in range(1, DensePoseDataRelative.N_PART_LABELS + 1):
            mask[fine_labels == fine_label] = _FINE_TO_COARSE_SEGMENTATION[fine_label]
        return mask