# fastfit/parse_utils/densepose/data/samplers/densepose_cse_base.py
# tigger13's picture
# Upload 452 files
# 2711c5f verified
# Copyright (c) Facebook, Inc. and its affiliates.
# pyre-unsafe
from typing import Any, Dict, List, Tuple
import torch
from torch.nn import functional as F
from ....detectron2.config import CfgNode
from ....detectron2.structures import Instances
from ...converters.base import IntTupleBox
from ...data.utils import get_class_to_mesh_name_mapping
from ...modeling.cse.utils import squared_euclidean_distance_matrix
from ...structures import DensePoseDataRelative
from .densepose_base import DensePoseBaseSampler
class DensePoseCSEBaseSampler(DensePoseBaseSampler):
    """
    Base DensePose sampler to produce DensePose data from DensePose CSE predictions.

    Samples for each class are drawn according to some distribution over all
    pixels estimated to belong to that class. The distribution itself is not
    defined here: `_sample` delegates the choice of pixels to
    `self._produce_index_sample`, which this class does not implement and which
    therefore has to come from the base class or a subclass.
    """

    def __init__(
        self,
        cfg: CfgNode,
        use_gt_categories: bool,
        embedder: torch.nn.Module,
        count_per_class: int = 8,
    ):
        """
        Constructor

        Args:
          cfg (CfgNode): the config of the model; used to build the mapping
              from class index to mesh name
          use_gt_categories (bool): if True, the class of an instance is read
              from `instance.dataset_classes`, otherwise from
              `instance.pred_classes`
          embedder (torch.nn.Module): necessary to compute mesh vertex embeddings;
              it is called with a mesh name
          count_per_class (int): the sampler produces at most `count_per_class`
              samples for each category
        """
        super().__init__(count_per_class)
        self.embedder = embedder
        self.class_to_mesh_name = get_class_to_mesh_name_mapping(cfg)
        self.use_gt_categories = use_gt_categories

    def _sample(self, instance: Instances, bbox_xywh: IntTupleBox) -> Dict[str, List[Any]]:
        """
        Sample DensePoseDataRelative from estimation results

        Args:
          instance (Instances): a single detection carrying `pred_densepose`
              and either `dataset_classes` or `pred_classes`
          bbox_xywh (IntTupleBox): the corresponding bounding box; only its
              width and height are used here
        Return:
          A dict with the keys `X_KEY`, `Y_KEY`, `VERTEX_IDS_KEY` (lists with
          one entry per sampled point) and `MESH_NAME_KEY` (the mesh name).
          The lists are empty if no foreground pixel is available or
          `count_per_class` is not positive.
        """
        # Only the first class entry is used: the instance is treated as a
        # single detection.
        if self.use_gt_categories:
            instance_class = instance.dataset_classes.tolist()[0]
        else:
            instance_class = instance.pred_classes.tolist()[0]
        mesh_name = self.class_to_mesh_name[instance_class]

        annotation = {
            DensePoseDataRelative.X_KEY: [],
            DensePoseDataRelative.Y_KEY: [],
            DensePoseDataRelative.VERTEX_IDS_KEY: [],
            DensePoseDataRelative.MESH_NAME_KEY: mesh_name,
        }

        mask, embeddings, other_values = self._produce_mask_and_results(instance, bbox_xywh)
        # (row indices, column indices) of all foreground pixels
        indices = torch.nonzero(mask, as_tuple=True)
        # [D, H, W] -> [H, W, D] -> [K, D]: one embedding per foreground pixel
        # NOTE(review): these embeddings are moved to CPU while the result of
        # `self.embedder(mesh_name)` is used as is - confirm that both end up
        # on the same device in the calling pipeline.
        selected_embeddings = embeddings.permute(1, 2, 0)[indices].cpu()
        values = other_values[:, indices[0], indices[1]]
        k = values.shape[1]

        count = min(self.count_per_class, k)
        if count <= 0:
            return annotation

        index_sample = self._produce_index_sample(values, count)
        # for every sampled pixel pick the mesh vertex with the nearest embedding
        closest_vertices = squared_euclidean_distance_matrix(
            selected_embeddings[index_sample], self.embedder(mesh_name)
        )
        closest_vertices = torch.argmin(closest_vertices, dim=1)

        # shift by half a pixel to address pixel centers
        sampled_y = indices[0][index_sample] + 0.5
        sampled_x = indices[1][index_sample] + 0.5
        # prepare / normalize data: rescale box-relative pixel coordinates
        # to the range [0, 256]
        _, _, w, h = bbox_xywh
        x = (sampled_x / w * 256.0).cpu().tolist()
        y = (sampled_y / h * 256.0).cpu().tolist()
        # extend annotations
        annotation[DensePoseDataRelative.X_KEY].extend(x)
        annotation[DensePoseDataRelative.Y_KEY].extend(y)
        annotation[DensePoseDataRelative.VERTEX_IDS_KEY].extend(closest_vertices.cpu().tolist())
        return annotation

    def _produce_mask_and_results(
        self, instance: Instances, bbox_xywh: IntTupleBox
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Method to get labels and DensePose results from an instance

        Args:
          instance (Instances): an instance whose `pred_densepose` field
              provides `coarse_segm` and `embedding` tensors
          bbox_xywh (IntTupleBox): the corresponding bounding box
        Return:
          mask (torch.Tensor): shape [H, W], boolean DensePose segmentation
              mask (True where the highest scoring coarse segmentation
              channel is not channel 0)
          embeddings (torch.Tensor): a tensor of shape [D, H, W],
              DensePose CSE Embeddings
          other_values (torch.Tensor): a tensor of shape [0, H, W],
              for potential other values
        """
        densepose_output = instance.pred_densepose
        S = densepose_output.coarse_segm
        E = densepose_output.embedding
        _, _, w, h = bbox_xywh
        # Resize both outputs to the box size and keep the first batch element.
        # `align_corners=False` is spelled out to match `_resample_mask`; it is
        # also what bilinear interpolation uses when the argument is omitted.
        embeddings = F.interpolate(E, size=(h, w), mode="bilinear", align_corners=False)[0]
        coarse_segm_resized = F.interpolate(
            S, size=(h, w), mode="bilinear", align_corners=False
        )[0]
        mask = coarse_segm_resized.argmax(0) > 0
        # The base CSE sampler has no extra per-pixel values; the empty first
        # dimension keeps the indexing done in `_sample` valid.
        other_values = torch.empty((0, h, w), device=E.device)
        return mask, embeddings, other_values

    def _resample_mask(self, output: Any) -> torch.Tensor:
        """
        Convert DensePose predictor output to segmentation annotation - tensors of size
        (256, 256) and type `int64`.

        Args:
          output: DensePose predictor output with the following attributes:
            - coarse_segm: tensor of size [N, D, H, W] with unnormalized coarse
              segmentation scores
        Return:
          Tensor of size (S, S) and type `int64` with coarse segmentation annotations,
          where S = DensePoseDataRelative.MASK_SIZE. The batch dimension is
          removed by a plain `squeeze()`, so the (S, S) shape holds for N = 1.
        """
        sz = DensePoseDataRelative.MASK_SIZE
        mask = (
            F.interpolate(output.coarse_segm, (sz, sz), mode="bilinear", align_corners=False)
            .argmax(dim=1)
            .long()
            .squeeze()
            .cpu()
        )
        return mask