|
|
|
|
| from typing import Any, Dict, List, Tuple
|
| import torch
|
| from torch.nn import functional as F
|
|
|
| from detectron2.config import CfgNode
|
| from detectron2.structures import Instances
|
|
|
| from densepose.converters.base import IntTupleBox
|
| from densepose.data.utils import get_class_to_mesh_name_mapping
|
| from densepose.modeling.cse.utils import squared_euclidean_distance_matrix
|
| from densepose.structures import DensePoseDataRelative
|
|
|
| from .densepose_base import DensePoseBaseSampler
|
|
|
|
|
class DensePoseCSEBaseSampler(DensePoseBaseSampler):
    """
    Base DensePose sampler to produce DensePose data from DensePose predictions.
    Samples for each class are drawn according to some distribution over all pixels estimated
    to belong to that class.

    The distribution itself is defined by `_produce_index_sample`, which is not
    implemented in this class (presumably provided by the base class or by subclasses).
    """

    def __init__(
        self,
        cfg: CfgNode,
        use_gt_categories: bool,
        embedder: torch.nn.Module,
        count_per_class: int = 8,
    ):
        """
        Constructor

        Args:
          cfg (CfgNode): the config of the model
          use_gt_categories (bool): if True, the instance category is read from
              `instance.dataset_classes`; otherwise from `instance.pred_classes`
          embedder (torch.nn.Module): necessary to compute mesh vertex embeddings
          count_per_class (int): the sampler produces at most `count_per_class`
              samples for each category
        """
        super().__init__(count_per_class)
        self.embedder = embedder
        self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
        self.use_gt_categories = use_gt_categories

    def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]:
        """
        Sample DensePoseDataRelative from estimation results

        Args:
          instance (Instances): a single detected instance with DensePose predictions
          bbox_xywh (IntTupleBox): the corresponding bounding box

        Return:
          Dictionary with the keys `DensePoseDataRelative.X_KEY`, `Y_KEY`,
          `VERTEX_IDS_KEY` (lists with one entry per sampled point) and
          `MESH_NAME_KEY` (the mesh name, a single value rather than a list).
          The lists are empty if the segmentation mask selects no pixels or
          `count_per_class` is not positive.
        """
        # only the first entry of the class tensor is used
        if self.use_gt_categories:
            instance_class = instance.dataset_classes.tolist()[0]
        else:
            instance_class = instance.pred_classes.tolist()[0]
        mesh_name = self.class_to_mesh_name[instance_class]

        annotation = {
            DensePoseDataRelative.X_KEY: [],
            DensePoseDataRelative.Y_KEY: [],
            DensePoseDataRelative.VERTEX_IDS_KEY: [],
            DensePoseDataRelative.MESH_NAME_KEY: mesh_name,
        }

        mask, embeddings, other_values = self._produce_mask_and_results(instance, bbox_xywh)
        # (row, column) index tensors of all pixels selected by the mask
        indices = torch.nonzero(mask, as_tuple=True)
        # [D, H, W] -> [H, W, D], then gather one embedding per selected pixel
        selected_embeddings = embeddings.permute(1, 2, 0)[indices].cpu()
        values = other_values[:, indices[0], indices[1]]
        k = values.shape[1]

        count = min(self.count_per_class, k)
        if count <= 0:
            return annotation

        index_sample = self._produce_index_sample(values, count)
        sampled_embeddings = selected_embeddings[index_sample]
        # The pixel embeddings were moved to CPU above, whereas the embedder may
        # produce its output on another device; align the devices so that the
        # distance computation does not mix them (no-op if they already match).
        vertex_embeddings = self.embedder(mesh_name).to(sampled_embeddings.device)
        distances = squared_euclidean_distance_matrix(sampled_embeddings, vertex_embeddings)
        # for every sampled pixel pick the mesh vertex with the smallest distance
        closest_vertices = torch.argmin(distances, dim=1)

        # offset the integer pixel indices by half a pixel
        sampled_y = indices[0][index_sample] + 0.5
        sampled_x = indices[1][index_sample] + 0.5
        # prepare / normalize data: rescale box-relative coordinates to [0, 256]
        _, _, w, h = bbox_xywh
        x = (sampled_x / w * 256.0).cpu().tolist()
        y = (sampled_y / h * 256.0).cpu().tolist()
        # extend annotations
        annotation[DensePoseDataRelative.X_KEY].extend(x)
        annotation[DensePoseDataRelative.Y_KEY].extend(y)
        annotation[DensePoseDataRelative.VERTEX_IDS_KEY].extend(closest_vertices.cpu().tolist())
        return annotation

    def _produce_mask_and_results(
        self, instance: Instances, bbox_xywh: IntTupleBox
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Method to get labels and DensePose results from an instance

        Args:
          instance (Instances): an instance whose `pred_densepose` field provides
              `coarse_segm` and `embedding` tensors
          bbox_xywh (IntTupleBox): the corresponding bounding box

        Return:
          mask (torch.Tensor): shape [H, W], boolean DensePose segmentation mask,
              True where the highest-scoring coarse segmentation channel is not 0
          embeddings (torch.Tensor): a tensor of shape [D, H, W],
              DensePose CSE Embeddings
          other_values (torch.Tensor): a tensor of shape [0, H, W],
              for potential other values

          Here H and W are the height and width of `bbox_xywh`.
        """
        densepose_output = instance.pred_densepose
        S = densepose_output.coarse_segm
        E = densepose_output.embedding
        _, _, w, h = bbox_xywh
        # `align_corners=False` is the default; it is spelled out for consistency
        # with `_resample_mask`. `[0]` keeps the first item along the batch dimension.
        embeddings = F.interpolate(E, size=(h, w), mode="bilinear", align_corners=False)[0]
        coarse_segm_resized = F.interpolate(
            S, size=(h, w), mode="bilinear", align_corners=False
        )[0]
        mask = coarse_segm_resized.argmax(0) > 0
        other_values = torch.empty((0, h, w), device=E.device)
        return mask, embeddings, other_values

    def _resample_mask(self, output: Any) -> torch.Tensor:
        """
        Convert DensePose predictor output to segmentation annotation - tensors of size
        (S, S) and type `int64`, where S = DensePoseDataRelative.MASK_SIZE.

        Args:
          output: DensePose predictor output with the following attributes:
           - coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse
             segmentation scores
        Return:
          Tensor of size (S, S) and type `int64` with coarse segmentation annotations,
          where S = DensePoseDataRelative.MASK_SIZE. The result is moved to CPU.
          Singleton dimensions are squeezed out, so the (S, S) shape holds for N = 1.
        """
        sz = DensePoseDataRelative.MASK_SIZE
        mask = (
            F.interpolate(output.coarse_segm, (sz, sz), mode="bilinear", align_corners=False)
            .argmax(dim=1)
            .long()
            .squeeze()
            .cpu()
        )
        return mask
|
|
|