| | |
| |
|
| | from typing import Any, Dict, List, Tuple |
| | import torch |
| | from torch.nn import functional as F |
| |
|
| | from detectron2.config import CfgNode |
| | from detectron2.structures import Instances |
| |
|
| | from densepose.converters.base import IntTupleBox |
| | from densepose.data.utils import get_class_to_mesh_name_mapping |
| | from densepose.modeling.cse.utils import squared_euclidean_distance_matrix |
| | from densepose.structures import DensePoseDataRelative |
| |
|
| | from .densepose_base import DensePoseBaseSampler |
| |
|
| |
|
class DensePoseCSEBaseSampler(DensePoseBaseSampler):
    """
    Base DensePose sampler that produces DensePose data from CSE DensePose
    predictions. For each category, samples are drawn according to some
    distribution over all pixels estimated to belong to that category.
    """

    def __init__(
        self,
        cfg: CfgNode,
        use_gt_categories: bool,
        embedder: torch.nn.Module,
        count_per_class: int = 8,
    ):
        """
        Constructor

        Args:
            cfg (CfgNode): the config of the model
            use_gt_categories (bool): if True, instance categories are taken
                from ground truth (`dataset_classes`); otherwise the predicted
                categories (`pred_classes`) are used
            embedder (torch.nn.Module): necessary to compute mesh vertex embeddings
            count_per_class (int): the sampler produces at most `count_per_class`
                samples for each category
        """
        super().__init__(count_per_class)
        self.use_gt_categories = use_gt_categories
        self.embedder = embedder
        self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)

    def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]:
        """
        Sample DensePoseDataRelative annotation data from estimation results.

        Args:
            instance (Instances): a single instance carrying CSE DensePose
                predictions (and category fields)
            bbox_xywh (IntTupleBox): the instance bounding box as (x, y, w, h)

        Return:
            Annotation dict with sampled X / Y coordinates (in the 256x256
            normalized box frame), closest mesh vertex IDs, and the mesh name
            for the instance category
        """
        classes = instance.dataset_classes if self.use_gt_categories else instance.pred_classes
        category = classes.tolist()[0]
        mesh_name = self.class_to_mesh_name[category]

        annotation = {
            DensePoseDataRelative.X_KEY: [],
            DensePoseDataRelative.Y_KEY: [],
            DensePoseDataRelative.VERTEX_IDS_KEY: [],
            DensePoseDataRelative.MESH_NAME_KEY: mesh_name,
        }

        mask, embeddings, other_values = self._produce_mask_and_results(instance, bbox_xywh)
        y_idx, x_idx = torch.nonzero(mask, as_tuple=True)
        # gather per-pixel embeddings for all foreground pixels: [K, D] on CPU
        selected_embeddings = embeddings.permute(1, 2, 0)[(y_idx, x_idx)].cpu()
        values = other_values[:, y_idx, x_idx]
        n_candidates = values.shape[1]

        count = min(self.count_per_class, n_candidates)
        if count <= 0:
            # nothing to sample from (empty mask) - return the empty annotation
            return annotation

        sample = self._produce_index_sample(values, count)
        # nearest mesh vertex for each sampled pixel embedding
        distances = squared_euclidean_distance_matrix(
            selected_embeddings[sample], self.embedder(mesh_name)
        )
        vertex_ids = distances.argmin(dim=1)

        # shift to pixel centers, then normalize into the 256x256 annotation frame
        _, _, box_w, box_h = bbox_xywh
        x_coords = ((x_idx[sample] + 0.5) / box_w * 256.0).cpu().tolist()
        y_coords = ((y_idx[sample] + 0.5) / box_h * 256.0).cpu().tolist()

        annotation[DensePoseDataRelative.X_KEY].extend(x_coords)
        annotation[DensePoseDataRelative.Y_KEY].extend(y_coords)
        annotation[DensePoseDataRelative.VERTEX_IDS_KEY].extend(vertex_ids.cpu().tolist())
        return annotation

    def _produce_mask_and_results(
        self, instance: Instances, bbox_xywh: IntTupleBox
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Extract the foreground mask and CSE results for one instance.

        Args:
            instance (Instances): an instance whose `pred_densepose` field is a
                `DensePoseEmbeddingPredictorOutput`
            bbox_xywh (IntTupleBox): the corresponding bounding box

        Return:
            mask (torch.Tensor): shape [H, W], boolean DensePose segmentation mask
            embeddings (torch.Tensor): shape [D, H, W], DensePose CSE embeddings
                resized to the bounding box
            other_values (torch.Tensor): empty tensor of shape [0, H, W],
                placeholder for potential other values
        """
        densepose_output = instance.pred_densepose
        _, _, w, h = bbox_xywh
        coarse_segm = F.interpolate(
            densepose_output.coarse_segm, size=(h, w), mode="bilinear"
        )[0]
        embeddings = F.interpolate(
            densepose_output.embedding, size=(h, w), mode="bilinear"
        )[0]
        # channel 0 is background: foreground wherever argmax over channels is > 0
        mask = coarse_segm.argmax(0) > 0
        other_values = torch.empty((0, h, w), device=embeddings.device)
        return mask, embeddings, other_values

    def _resample_mask(self, output: Any) -> torch.Tensor:
        """
        Convert DensePose predictor output to segmentation annotation - tensors of size
        (256, 256) and type `int64`.

        Args:
            output: DensePose predictor output with the following attributes:
                - coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse
                  segmentation scores
        Return:
            Tensor of size (S, S) and type `int64` with coarse segmentation annotations,
            where S = DensePoseDataRelative.MASK_SIZE
        """
        size = DensePoseDataRelative.MASK_SIZE
        scores = F.interpolate(output.coarse_segm, (size, size), mode="bilinear", align_corners=False)
        # per-pixel label = channel with highest score; move to CPU as int64
        return scores.argmax(dim=1).long().squeeze().cpu()
| |
|