|
|
|
|
| from typing import Any, Dict, List, Tuple
|
| import torch
|
| from torch.nn import functional as F
|
|
|
| from detectron2.structures import BoxMode, Instances
|
|
|
| from densepose.converters import ToChartResultConverter
|
| from densepose.converters.base import IntTupleBox, make_int_box
|
| from densepose.structures import DensePoseDataRelative, DensePoseList
|
|
|
|
|
class DensePoseBaseSampler:
    """
    Base DensePose sampler to produce DensePose data from DensePose predictions.
    Samples for each class are drawn according to some distribution over all pixels estimated
    to belong to that class.

    Descendants must implement `_produce_index_sample`, which selects the
    pixels that are kept for each part.
    """

    # Fine part label -> coarse segmentation label. Label 0 is never remapped
    # and therefore stays 0. Kept at class level so the table is built once
    # instead of on every `_resample_mask` call.
    _FINE_TO_COARSE_SEGMENTATION = {
        1: 1,
        2: 1,
        3: 2,
        4: 3,
        5: 4,
        6: 5,
        7: 6,
        8: 7,
        9: 6,
        10: 7,
        11: 8,
        12: 9,
        13: 8,
        14: 9,
        15: 10,
        16: 11,
        17: 10,
        18: 11,
        19: 12,
        20: 13,
        21: 12,
        22: 13,
        23: 14,
        24: 14,
    }

    def __init__(self, count_per_class: int = 8):
        """
        Constructor

        Args:
          count_per_class (int): the sampler produces at most `count_per_class`
              samples for each category
        """
        self.count_per_class = count_per_class

    def __call__(self, instances: Instances) -> DensePoseList:
        """
        Convert DensePose predictions stored in `instances` into DensePose
        annotations data (an instance of `DensePoseList`)

        Args:
          instances (Instances): detections; `pred_boxes`, `pred_densepose`
              and `image_size` are read from it
        Return:
          DensePoseList: one `DensePoseDataRelative` entry per detection,
              built together with the XYXY boxes (copied to CPU) and the image size
        """
        boxes_xyxy_abs = instances.pred_boxes.tensor.clone().cpu()
        boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
        dp_datas = []
        for i, box_xywh in enumerate(boxes_xywh_abs):
            # Index once and reuse: both the sampled points and the
            # segmentation mask come from the same detection.
            instance_i = instances[i]
            annotation_i = self._sample(instance_i, make_int_box(box_xywh))
            annotation_i[DensePoseDataRelative.S_KEY] = self._resample_mask(
                instance_i.pred_densepose
            )
            dp_datas.append(DensePoseDataRelative(annotation_i))

        return DensePoseList(dp_datas, boxes_xyxy_abs, instances.image_size)

    def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]:
        """
        Sample DensePoseDataRelative from estimation results

        Args:
          instance (Instances): a single detection, passed on to
              `_produce_labels_and_results`
          bbox_xywh (IntTupleBox): bounding box of the detection; not used by
              this implementation, kept as part of the method signature
        Return:
          dict: lists of sampled point data stored under the X, Y, U, V and I
              keys of `DensePoseDataRelative`; the lists have equal length
        """
        labels, dp_result = self._produce_labels_and_results(instance)
        annotation = {
            DensePoseDataRelative.X_KEY: [],
            DensePoseDataRelative.Y_KEY: [],
            DensePoseDataRelative.U_KEY: [],
            DensePoseDataRelative.V_KEY: [],
            DensePoseDataRelative.I_KEY: [],
        }
        n, h, w = dp_result.shape
        # The label map is the same for every channel; broadcasting it is a
        # view, so it is created once rather than once per part.
        labels_nhw = labels.expand(n, h, w)
        for part_id in range(1, DensePoseDataRelative.N_PART_LABELS + 1):
            # indices - tuple of 3 1D tensors (channel, row, column) with one
            # entry per (channel, pixel) pair labeled with part_id. Entries
            # are ordered channel by channel, so the first k entries are the
            # k pixels of channel 0.
            indices = torch.nonzero(labels_nhw == part_id, as_tuple=True)
            # values - an array of size [n, k]
            # n: number of channels
            # k: number of points labeled with part_id
            values = dp_result[indices].view(n, -1)
            k = values.shape[1]
            count = min(self.count_per_class, k)
            if count <= 0:
                continue
            index_sample = self._produce_index_sample(values, count)
            sampled_values = values[:, index_sample]
            # Sample positions are in [0, k), i.e. within the channel-0 part
            # of `indices`; + 0.5 moves the coordinate to the pixel center.
            sampled_y = indices[1][index_sample] + 0.5
            sampled_x = indices[2][index_sample] + 0.5
            # Rescale pixel coordinates from the [0, w] x [0, h] range to
            # [0, 256] (presumably the coordinate range expected by
            # DensePoseDataRelative - TODO confirm).
            x = (sampled_x / w * 256.0).cpu().tolist()
            y = (sampled_y / h * 256.0).cpu().tolist()
            u = sampled_values[0].clamp(0, 1).cpu().tolist()
            v = sampled_values[1].clamp(0, 1).cpu().tolist()
            fine_segm_labels = [part_id] * count

            annotation[DensePoseDataRelative.X_KEY].extend(x)
            annotation[DensePoseDataRelative.Y_KEY].extend(y)
            annotation[DensePoseDataRelative.U_KEY].extend(u)
            annotation[DensePoseDataRelative.V_KEY].extend(v)
            annotation[DensePoseDataRelative.I_KEY].extend(fine_segm_labels)
        return annotation

    def _produce_index_sample(self, values: torch.Tensor, count: int):
        """
        Abstract method to produce a sample of indices to select data
        To be implemented in descendants

        Args:
          values (torch.Tensor): an array of size [n, k] that contains
              estimated values (U, V, confidences);
              n: number of channels (U, V, confidences)
              k: number of points labeled with part_id
          count (int): number of samples to produce, should be positive and <= k

        Return:
          list(int): indices of values (along axis 1) selected as a sample
        """
        raise NotImplementedError

    def _produce_labels_and_results(self, instance: Instances) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Method to get labels and DensePose results from an instance

        Args:
          instance (Instances): a single detection; its `pred_densepose` and
              `pred_boxes` fields are passed to `ToChartResultConverter`

        Return:
          labels (torch.Tensor): shape [H, W], DensePose segmentation labels
          dp_result (torch.Tensor): shape [2, H, W], stacked DensePose results u and v
          Both tensors are moved to CPU.
        """
        converter = ToChartResultConverter
        chart_result = converter.convert(instance.pred_densepose, instance.pred_boxes)
        labels, dp_result = chart_result.labels.cpu(), chart_result.uv.cpu()
        return labels, dp_result

    def _resample_mask(self, output: Any) -> torch.Tensor:
        """
        Convert DensePose predictor output to segmentation annotation - a tensor of size
        (S, S) and type `int64`, where S = DensePoseDataRelative.MASK_SIZE.

        Args:
          output: DensePose predictor output with the following attributes:
           - coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse
             segmentation scores
           - fine_segm: tensor of size [N, C, H, W] with unnormalized fine
             segmentation scores
           The output is expected to hold a single instance (N = 1), so that
           squeezing the resampled labels yields an (S, S) tensor.
        Return:
          Tensor of size (S, S) and type `int64` with coarse segmentation annotations,
          where S = DensePoseDataRelative.MASK_SIZE
        """
        sz = DensePoseDataRelative.MASK_SIZE
        # Per-pixel argmax over the resampled scores; argmax already yields int64.
        coarse_labels = F.interpolate(
            output.coarse_segm, (sz, sz), mode="bilinear", align_corners=False
        ).argmax(dim=1)
        fine_labels = F.interpolate(
            output.fine_segm, (sz, sz), mode="bilinear", align_corners=False
        ).argmax(dim=1)
        # Fine labels only count where the coarse segmentation is foreground.
        fine_labels = (fine_labels * (coarse_labels > 0).long()).squeeze().cpu()
        return self._fine_to_coarse(fine_labels, DensePoseDataRelative.N_PART_LABELS)

    @classmethod
    def _fine_to_coarse(cls, fine_labels: torch.Tensor, n_part_labels: int) -> torch.Tensor:
        """
        Map fine part labels to coarse segmentation labels.

        Args:
          fine_labels (torch.Tensor): integer tensor of fine part labels
          n_part_labels (int): labels 1..n_part_labels are remapped using
              `_FINE_TO_COARSE_SEGMENTATION`; every other value maps to 0

        Return:
          Tensor of type `int64` with the same shape and device as `fine_labels`

        Raises:
          KeyError: if `n_part_labels` exceeds the largest label in the table
        """
        coarse = torch.zeros_like(fine_labels, dtype=torch.int64)
        for fine_id in range(1, n_part_labels + 1):
            coarse[fine_labels == fine_id] = cls._FINE_TO_COARSE_SEGMENTATION[fine_id]
        return coarse
|
|
|