| | |
| |
|
| | import json |
| | import logging |
| | from typing import List, Optional |
| | import torch |
| | from torch import nn |
| |
|
| | from detectron2.utils.file_io import PathManager |
| |
|
| | from densepose.structures.mesh import create_mesh |
| |
|
| |
|
class MeshAlignmentEvaluator:
    """
    Class for evaluation of 3D mesh alignment based on the learned vertex embeddings.

    For every ordered pair of distinct meshes (mesh_1, mesh_2), each key vertex
    of mesh_1 is matched to the vertex of mesh_2 whose embedding has the highest
    dot-product similarity with it. Each match is compared with the vertex of
    mesh_2 that carries the same key vertex name, and scored with:

    - GE: mean of the ``mesh_2.geodists`` values between matched and expected
      vertices;
    - GPS: mean of ``exp(-d**2 / (2 * GPS_SIGMA**2))`` over the same values ``d``.
    """

    # JSON resource mapping mesh name -> {key vertex name: vertex index}
    KEYVERTICES_URL = (
        "https://dl.fbaipublicfiles.com/densepose/data/cse/mesh_keyvertices_v0.json"
    )
    # Scale of the Gaussian kernel used to compute GPS from distances
    # (0.255 is the value hard-coded by the original implementation).
    GPS_SIGMA = 0.255

    def __init__(self, embedder: nn.Module, mesh_names: Optional[List[str]] = None):
        """
        Args:
            embedder: callable module mapping a mesh name to a 2D tensor of
                vertex embeddings (rows are indexed by vertex index). When
                ``mesh_names`` is not given, it must expose ``mesh_names``.
            mesh_names: meshes to evaluate. If None or empty,
                ``embedder.mesh_names`` is used instead.
        """
        self.embedder = embedder
        # falsy check on purpose: an empty list also falls back to the embedder's meshes
        self.mesh_names = mesh_names if mesh_names else embedder.mesh_names
        self.logger = logging.getLogger(__name__)
        # NOTE: the key vertex annotations are fetched at construction time
        with PathManager.open(self.KEYVERTICES_URL, "r") as f:
            # {mesh name: {key vertex name: vertex index}}
            self.mesh_keyvertices = json.load(f)

    def evaluate(self):
        """
        Compute GE and GPS over all ordered pairs of distinct meshes.

        For each source mesh, the per-pair metrics are averaged over all target
        meshes; the global metrics are then averaged over the source meshes.

        Returns:
            Tuple ``(ge_mean_global, gps_mean_global, per_mesh_metrics)``:
            two floats and a dict ``{"GE": {mesh_name: float},
            "GPS": {mesh_name: float}}``. A mesh with no other mesh to compare
            against gets NaN (mean of an empty tensor), which then propagates
            to the global means.
        """
        ge_per_mesh = {}
        gps_per_mesh = {}
        # Only Python floats leave this method, so no autograd graph is needed.
        with torch.no_grad():
            # Embeddings depend only on the mesh name, not on the pair being
            # evaluated: compute them once per mesh instead of once per pair.
            embeddings = {name: self.embedder(name) for name in self.mesh_names}
            for mesh_name_1 in self.mesh_names:
                keyvertices_1 = self.mesh_keyvertices[mesh_name_1]
                keyvertex_names_1 = list(keyvertices_1.keys())
                keyvertex_indices_1 = [keyvertices_1[name] for name in keyvertex_names_1]
                # embeddings of mesh_1 key vertices, one row per key vertex
                keyvertex_embeddings_1 = embeddings[mesh_name_1][keyvertex_indices_1]
                avg_errors = []
                avg_gps = []
                for mesh_name_2 in self.mesh_names:
                    if mesh_name_1 == mesh_name_2:
                        continue
                    embeddings_2 = embeddings[mesh_name_2]
                    keyvertices_2 = self.mesh_keyvertices[mesh_name_2]
                    # similarity of every mesh_1 key vertex to every mesh_2 vertex
                    sim_matrix_12 = keyvertex_embeddings_1.mm(embeddings_2.T)
                    # best-matching mesh_2 vertex for each mesh_1 key vertex
                    vertices_2_matching_keyvertices_1 = sim_matrix_12.argmax(dim=1)
                    mesh_2 = create_mesh(mesh_name_2, embeddings_2.device)
                    # expected counterparts: mesh_2 key vertices with the same names
                    # (raises KeyError if mesh_2 lacks a name present on mesh_1)
                    keyvertex_indices_2 = [keyvertices_2[name] for name in keyvertex_names_1]
                    # two equal-length index sequences -> element-wise pairs, i.e.
                    # one value per key vertex; geodists presumably holds pairwise
                    # geodesic distances between mesh_2 vertices (TODO confirm)
                    geodists = mesh_2.geodists[
                        vertices_2_matching_keyvertices_1,
                        keyvertex_indices_2,
                    ]
                    gps = (-(geodists**2) / (2 * (self.GPS_SIGMA**2))).exp()
                    avg_errors.append(geodists.mean().item())
                    avg_gps.append(gps.mean().item())
                ge_per_mesh[mesh_name_1] = torch.as_tensor(avg_errors).mean().item()
                gps_per_mesh[mesh_name_1] = torch.as_tensor(avg_gps).mean().item()
        ge_mean_global = torch.as_tensor(list(ge_per_mesh.values())).mean().item()
        gps_mean_global = torch.as_tensor(list(gps_per_mesh.values())).mean().item()
        per_mesh_metrics = {
            "GE": ge_per_mesh,
            "GPS": gps_per_mesh,
        }
        return ge_mean_global, gps_mean_global, per_mesh_metrics
| |
|