| | |
| |
|
| | import random |
| | from typing import Tuple |
| | import torch |
| | from torch import nn |
| | from torch.nn import functional as F |
| |
|
| | from detectron2.config import CfgNode |
| |
|
| | from densepose.structures.mesh import create_mesh |
| |
|
| | from .utils import sample_random_indices |
| |
|
| |
|
class ShapeToShapeCycleLoss(nn.Module):
    """
    Cycle Loss for Shapes.
    Inspired by:
    "Mapping in a Cycle: Sinkhorn Regularized Unsupervised Learning for Point Cloud Shapes".

    On every forward pass one pair of different meshes is drawn, soft
    correspondences between their vertex embeddings are computed in both
    directions, and the two round-trip (cycle) correspondence matrices are
    penalized with the geodesic distances of the respective mesh.
    """

    def __init__(self, cfg: CfgNode):
        """
        Args:
            cfg (CfgNode): config; the mesh names are the keys of
                MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS, the loss parameters
                (NORM_P, TEMPERATURE, MAX_NUM_VERTICES) are read from
                MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS
        """
        super().__init__()
        cse_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE
        loss_cfg = cse_cfg.SHAPE_TO_SHAPE_CYCLE_LOSS
        self.shape_names = list(cse_cfg.EMBEDDERS.keys())
        # all unordered pairs of different shapes, each pair listed exactly once
        self.all_shape_pairs = [
            (x, y) for i, x in enumerate(self.shape_names) for y in self.shape_names[i + 1 :]
        ]
        random.shuffle(self.all_shape_pairs)
        # position of the next pair to be returned by `_sample_random_pair`
        self.cur_pos = 0
        self.norm_p = loss_cfg.NORM_P
        self.temperature = loss_cfg.TEMPERATURE
        self.max_num_vertices = loss_cfg.MAX_NUM_VERTICES

    def _sample_random_pair(self) -> Tuple[str, str]:
        """
        Produce a random pair of different mesh names. Pairs are served from a
        shuffled list; once the list is exhausted, it is reshuffled and
        served again from the start.

        Return:
            tuple(str, str): a pair of different mesh names
        Raises:
            IndexError: if there are no pairs to sample from (fewer than two
                mesh names were configured)
        """
        if not self.all_shape_pairs:
            # same exception type as plain list indexing would raise here,
            # but with a message that points at the actual cause
            raise IndexError(
                "ShapeToShapeCycleLoss requires at least two mesh embedders "
                "to sample a pair of different shapes"
            )
        if self.cur_pos >= len(self.all_shape_pairs):
            random.shuffle(self.all_shape_pairs)
            self.cur_pos = 0
        shape_pair = self.all_shape_pairs[self.cur_pos]
        self.cur_pos += 1
        return shape_pair

    def forward(self, embedder: nn.Module):
        """
        Do a forward pass with a random (src, dst) pair of shapes
        Args:
            embedder (nn.Module): module that computes vertex embeddings for different meshes
        Return:
            Tensor containing the loss value
        """
        src_mesh_name, dst_mesh_name = self._sample_random_pair()
        return self._forward_one_pair(embedder, src_mesh_name, dst_mesh_name)

    def fake_value(self, embedder: nn.Module):
        """
        Produce a zero-valued loss that is still computed from the embeddings
        of every mesh listed in `embedder.mesh_names`.
        Args:
            embedder (nn.Module): module that computes vertex embeddings for different meshes
        Return:
            Scalar tensor equal to zero
        """
        # NOTE(review): presumably used to keep all embedder outputs in the
        # autograd graph when the real loss is not computed - confirm with callers
        losses = [embedder(mesh_name).sum() * 0 for mesh_name in embedder.mesh_names]
        return torch.mean(torch.stack(losses))

    @staticmethod
    def _take_square_submatrix(matrix: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        """
        Select the rows and columns of a 2D tensor given by the same 1D index tensor.
        Args:
            matrix (torch.Tensor of size [M, M]): tensor to select from
            indices (torch.Tensor of size [N]): 1D integer tensor of selected positions
        Return:
            torch.Tensor of size [N, N] with out[i, j] = matrix[indices[i], indices[j]]
        """
        # broadcasting a column of indices against a row of indices yields the
        # same result as indexing with `torch.meshgrid(indices, indices)`, but
        # without the deprecation warning for the missing `indexing` argument
        # and without materializing two [N, N] index grids
        return matrix[indices[:, None], indices[None, :]]

    def _get_embeddings_and_geodists_for_mesh(
        self, embedder: nn.Module, mesh_name: str
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Produces embeddings and geodesic distance tensors for a given mesh. May subsample
        the mesh, if it contains too many vertices (controlled by
        SHAPE_TO_SHAPE_CYCLE_LOSS.MAX_NUM_VERTICES parameter).
        Args:
            embedder (nn.Module): module that computes embeddings for mesh vertices
            mesh_name (str): mesh name
        Return:
            embeddings (torch.Tensor of size [N, D]): embeddings for selected mesh
                vertices (N = number of selected vertices, D = embedding space dim)
            geodists (torch.Tensor of size [N, N]): geodesic distances for the selected
                mesh vertices (N = number of selected vertices)
        """
        embeddings = embedder(mesh_name)
        # `indices` is None when no subsampling is to be applied; otherwise it
        # is assumed to be a 1D integer tensor of vertex indices - TODO confirm
        # against `sample_random_indices`
        indices = sample_random_indices(
            embeddings.shape[0], self.max_num_vertices, embeddings.device
        )
        mesh = create_mesh(mesh_name, embeddings.device)
        geodists = mesh.geodists
        if indices is not None:
            embeddings = embeddings[indices]
            geodists = self._take_square_submatrix(geodists, indices)
        return embeddings, geodists

    def _forward_one_pair(
        self, embedder: nn.Module, mesh_name_1: str, mesh_name_2: str
    ) -> torch.Tensor:
        """
        Do a forward pass with a selected pair of meshes
        Args:
            embedder (nn.Module): module that computes vertex embeddings for different meshes
            mesh_name_1 (str): first mesh name
            mesh_name_2 (str): second mesh name
        Return:
            Tensor containing the loss value
        """
        embeddings_1, geodists_1 = self._get_embeddings_and_geodists_for_mesh(embedder, mesh_name_1)
        embeddings_2, geodists_2 = self._get_embeddings_and_geodists_for_mesh(embedder, mesh_name_2)
        # pairwise dot-product similarities between vertices of mesh 1 and mesh 2
        sim_matrix_12 = embeddings_1.mm(embeddings_2.T)

        # row-wise soft correspondences in both directions (1 -> 2 and 2 -> 1)
        c_12 = F.softmax(sim_matrix_12 / self.temperature, dim=1)
        c_21 = F.softmax(sim_matrix_12.T / self.temperature, dim=1)
        # round trips 1 -> 2 -> 1 and 2 -> 1 -> 2
        c_11 = c_12.mm(c_21)
        c_22 = c_21.mm(c_12)

        # weight each round-trip correspondence by the geodesic distance
        # between the start vertex and the vertex it is mapped back to
        loss_cycle_11 = torch.norm(geodists_1 * c_11, p=self.norm_p)
        loss_cycle_22 = torch.norm(geodists_2 * c_22, p=self.norm_p)

        return loss_cycle_11 + loss_cycle_22
| |
|