| |
|
|
| from typing import Any, Dict, List |
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
|
|
| from detectron2.config import CfgNode |
| from detectron2.structures import Instances |
|
|
| from densepose.data.meshes.catalog import MeshCatalog |
| from densepose.modeling.cse.utils import normalize_embeddings, squared_euclidean_distance_matrix |
| from densepose.structures.mesh import create_mesh |
|
|
| from .embed_utils import PackedCseAnnotations |
| from .utils import BilinearInterpolationHelper |
|
|
|
|
class SoftEmbeddingLoss:
    """
    Computes losses for estimated embeddings given annotated vertices.
    Instances in a minibatch that correspond to the same mesh are grouped
    together. For each group, loss is computed as cross-entropy for
    unnormalized scores given ground truth mesh vertex ids.
    Scores are based on:
     1) squared distances between estimated vertex embeddings
        and mesh vertex embeddings;
     2) geodesic distances between vertices of a mesh
    """

    def __init__(self, cfg: CfgNode):
        """
        Initialize embedding loss from config

        Args:
            cfg (CfgNode): config; the following entries are read:
                * MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_DIST_GAUSS_SIGMA
                * MODEL.ROI_DENSEPOSE_HEAD.CSE.GEODESIC_DIST_GAUSS_SIGMA
        """
        cse_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE
        # divisor applied to squared embedding distances before log-softmax
        self.embdist_gauss_sigma = cse_cfg.EMBEDDING_DIST_GAUSS_SIGMA
        # divisor applied to geodesic distances before softmax
        self.geodist_gauss_sigma = cse_cfg.GEODESIC_DIST_GAUSS_SIGMA

    def __call__(
        self,
        proposals_with_gt: List[Instances],
        densepose_predictor_outputs: Any,
        packed_annotations: PackedCseAnnotations,
        interpolator: BilinearInterpolationHelper,
        embedder: nn.Module,
    ) -> Dict[str, torch.Tensor]:
        """
        Produces losses for estimated embeddings given annotated vertices.
        Embeddings for all the vertices of a mesh are computed by the embedder.
        Embeddings for observed pixels are estimated by a predictor.
        Losses are computed as cross-entropy for unnormalized scores given
        ground truth vertex IDs. Scores are based on:
         1) squared distances between estimated vertex embeddings
            and mesh vertex embeddings;
         2) geodesic distances between vertices of a mesh

        Args:
            proposals_with_gt (list of Instances): detections with associated
                ground truth data; each item corresponds to instances detected
                on 1 image; the number of items corresponds to the number of
                images in a batch (not read by this method)
            densepose_predictor_outputs: an object of a dataclass that contains predictor
                outputs with estimated values; assumed to have the following attributes:
                * embedding - embedding estimates, tensor of shape [N, D, S, S], where
                  N = number of instances (= sum N_i, where N_i is the number of
                      instances on image i)
                  D = embedding space dimensionality (MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE)
                  S = output size (width and height)
            packed_annotations (PackedCseAnnotations): contains various data useful
                for loss computation, each data is packed into a single tensor;
                `vertex_mesh_ids_gt` and `vertex_ids_gt` are read here
            interpolator (BilinearInterpolationHelper): bilinear interpolation helper
            embedder (nn.Module): module that computes vertex embeddings for different
                meshes; called as `embedder(mesh_name)` and expected to expose
                `mesh_names`
        Return:
            dict(str -> tensor): losses for different meshes, keyed by mesh name;
                every name in `embedder.mesh_names` is present, meshes without
                valid annotated points get a zero-valued placeholder
                (see `fake_value`)
        """
        losses = {}
        # Interpolated embedding estimates do not depend on the mesh, so they are
        # extracted at most once and only when some mesh has valid points.
        point_embeddings = None
        for mesh_id_tensor in packed_annotations.vertex_mesh_ids_gt.unique():
            mesh_id = mesh_id_tensor.item()
            mesh_name = MeshCatalog.get_mesh_name(mesh_id)

            # valid points: flagged valid by the interpolator
            # and annotated for the current mesh
            j_valid = interpolator.j_valid * (
                packed_annotations.vertex_mesh_ids_gt == mesh_id
            )
            if not torch.any(j_valid):
                continue

            if point_embeddings is None:
                point_embeddings = interpolator.extract_at_points(
                    densepose_predictor_outputs.embedding,
                    slice_fine_segm=slice(None),
                    w_ylo_xlo=interpolator.w_ylo_xlo[:, None],
                    w_ylo_xhi=interpolator.w_ylo_xhi[:, None],
                    w_yhi_xlo=interpolator.w_yhi_xlo[:, None],
                    w_yhi_xhi=interpolator.w_yhi_xhi[:, None],
                )

            # normalized estimated embeddings for valid points
            # (presumably tensor [J, D], J = number of valid points - TODO confirm)
            vertex_embeddings_i = normalize_embeddings(point_embeddings[j_valid, :])

            # ground truth vertex ids for valid points
            vertex_indices_i = packed_annotations.vertex_ids_gt[j_valid]

            # embeddings for all mesh vertices
            # (presumably tensor [K, D], K = number of mesh vertices - TODO confirm)
            mesh_vertex_embeddings = embedder(mesh_name)

            # soft targets: softmax over negated, scaled geodesic distances
            # from each ground truth vertex to the mesh vertices
            mesh = create_mesh(mesh_name, mesh_vertex_embeddings.device)
            geodist_softmax_values = F.softmax(
                mesh.geodists[vertex_indices_i] / (-self.geodist_gauss_sigma), dim=1
            )

            # log-probabilities: log-softmax over negated, scaled squared distances
            # between estimated embeddings and mesh vertex embeddings
            embdist_logsoftmax_values = F.log_softmax(
                squared_euclidean_distance_matrix(vertex_embeddings_i, mesh_vertex_embeddings)
                / (-self.embdist_gauss_sigma),
                dim=1,
            )

            # cross-entropy with soft targets: sum over dim 1, mean over points
            losses[mesh_name] = (-geodist_softmax_values * embdist_logsoftmax_values).sum(1).mean()

        # meshes without a computed loss still get an entry
        for mesh_name in embedder.mesh_names:
            if mesh_name not in losses:
                losses[mesh_name] = self.fake_value(
                    densepose_predictor_outputs, embedder, mesh_name
                )
        return losses

    def fake_values(
        self, densepose_predictor_outputs: Any, embedder: nn.Module
    ) -> Dict[str, torch.Tensor]:
        """
        Produces zero-valued placeholder losses for all meshes known to the embedder.

        Args:
            densepose_predictor_outputs: predictor outputs, must have
                an `embedding` tensor attribute
            embedder (nn.Module): module that computes vertex embeddings for
                different meshes; its `mesh_names` are iterated
        Return:
            dict(str -> tensor): zero-valued loss for each name in `embedder.mesh_names`
        """
        return {
            mesh_name: self.fake_value(densepose_predictor_outputs, embedder, mesh_name)
            for mesh_name in embedder.mesh_names
        }

    def fake_value(
        self, densepose_predictor_outputs: Any, embedder: nn.Module, mesh_name: str
    ) -> torch.Tensor:
        """
        Produces a zero-valued placeholder loss for a single mesh.

        Args:
            densepose_predictor_outputs: predictor outputs, must have
                an `embedding` tensor attribute
            embedder (nn.Module): module that computes vertex embeddings for
                different meshes
            mesh_name (str): name of the mesh to produce the placeholder for
        Return:
            tensor: scalar zero computed from both the estimated embeddings and
                the mesh vertex embeddings
        """
        # The value is always 0 but is derived from both tensors, presumably so
        # that the predictor and the embedder stay connected to the loss in the
        # autograd graph (verify against the training setup).
        return densepose_predictor_outputs.embedding.sum() * 0 + embedder(mesh_name).sum() * 0
|
|