| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from lib.dataset.mesh_util import projection |
| from lib.common.render import Render |
| import numpy as np |
| import torch |
| import os.path as osp |
| from torchvision.utils import make_grid |
| from pytorch3d.io import IO |
| from pytorch3d.ops import sample_points_from_meshes |
| from pytorch3d.loss.point_mesh_distance import _PointFaceDistance |
| from pytorch3d.structures import Pointclouds |
| from PIL import Image |
|
|
|
|
def point_mesh_distance(meshes, pcls, min_triangle_area=5e-3):
    """Return the batch-averaged point-to-face distance from clouds to meshes.

    Each point cloud is paired with the mesh at the same batch index. For
    every point the distance to the closest triangle of its mesh is obtained
    from PyTorch3D's ``_PointFaceDistance``; the square root of that value is
    averaged over the points of each cloud, and the per-cloud averages are
    then averaged over the batch.

    Args:
        meshes: Batch of meshes exposing ``verts_packed()``,
            ``faces_packed()`` and ``mesh_to_faces_packed_first_idx()``
            (presumably a PyTorch3D ``Meshes`` object).
        pcls: Batch of point clouds exposing the packed accessors used below
            (presumably a PyTorch3D ``Pointclouds`` object).
        min_triangle_area: Forwarded unchanged as the last argument of
            ``_PointFaceDistance.apply``. Defaults to ``5e-3``, the value
            that used to be hard-coded here.

    Returns:
        A scalar tensor holding the averaged distance.

    Raises:
        ValueError: If ``meshes`` and ``pcls`` have different batch sizes.
    """
    if len(meshes) != len(pcls):
        raise ValueError("meshes and pointclouds must be equal sized batches")
    batch_size = len(meshes)

    # Packed representation of the point clouds.
    points = pcls.points_packed()
    points_first_idx = pcls.cloud_to_packed_first_idx()
    num_points_per_cloud = pcls.num_points_per_cloud()
    max_points = num_points_per_cloud.max().item()

    # Packed representation of the mesh triangles: indexing the vertices by
    # the face index tensor gathers the corner coordinates of every face.
    tris = meshes.verts_packed()[meshes.faces_packed()]
    tris_first_idx = meshes.mesh_to_faces_packed_first_idx()

    # NOTE(review): PyTorch3D documents this result as the *squared* distance
    # of each point to its closest face, which is why a sqrt follows.
    point_to_face = _PointFaceDistance.apply(
        points, points_first_idx, tris, tris_first_idx, max_points,
        min_triangle_area)

    # Weight each point by 1 / (number of points in its own cloud) so that
    # the sum below is a per-cloud mean, then average over the batch.
    point_to_cloud_idx = pcls.packed_to_cloud_idx()
    weights_p = 1.0 / num_points_per_cloud.gather(0, point_to_cloud_idx).float()
    point_dist = (torch.sqrt(point_to_face) * weights_p).sum() / batch_size

    return point_dist
|
|
|
|
class Evaluator:
    """Compute reconstruction metrics between a predicted and a reference mesh.

    Typical use: construct once, call :meth:`set_mesh` with the result
    dictionary of one sample, then query :meth:`calculate_normal_consist`
    and :meth:`calculate_chamfer_p2s`. :meth:`calc_acc` is independent of
    the mesh state and works directly on tensors.
    """

    def __init__(self, device):
        """Create the renderer used by the normal-consistency metric.

        Args:
            device: Device handed to ``Render``; also stored as
                ``self.device``.
        """
        self.render = Render(size=512, device=device)
        self.device = device

    def set_mesh(self, result_dict):
        """Load one sample and build the source / target meshes.

        Every ``key: value`` pair of ``result_dict`` becomes an attribute of
        this object. The method then reads ``verts_pr``, ``faces_pr``,
        ``verts_gt``, ``faces_gt``, ``recon_size`` and ``calib``.

        Note that ``verts_pr`` is rescaled in place, so the tensor stored in
        ``result_dict`` is modified as well.

        Args:
            result_dict: Mapping providing at least the attributes listed
                above.
        """
        for key, value in result_dict.items():
            setattr(self, key, value)

        # Shift and scale predicted vertices by half the reconstruction size
        # (maps [0, recon_size] onto [-1, 1]).
        half_size = self.recon_size / 2.0
        self.verts_pr -= half_size
        self.verts_pr /= half_size

        # Bring the reference vertices through the calibration and flip the
        # sign of their second coordinate.
        self.verts_gt = projection(self.verts_gt, self.calib)
        self.verts_gt[:, 1] *= -1

        self.src_mesh = self.render.VF2Mesh(self.verts_pr, self.faces_pr)
        self.tgt_mesh = self.render.VF2Mesh(self.verts_gt, self.faces_gt)

    @staticmethod
    def _unit_normal_to_rgb(normal_map):
        """Normalise per-pixel vectors to unit length and remap to [0, 1].

        The channel axis is taken as the third-from-last one, so both
        ``(C, H, W)`` and ``(B, C, H, W)`` inputs are handled. All-zero
        pixels (zero norm) are left at zero and therefore map to 0.5.
        The input tensor is not modified.
        """
        length = torch.norm(normal_map, dim=-3, keepdim=True)
        length[length == 0.0] = 1.0  # avoid dividing background pixels by 0
        return (normal_map / length + 1.0) * 0.5

    @staticmethod
    def _normal_error(src_rgb, tgt_rgb):
        """Mean over pixels of the channel-summed squared difference, times 4."""
        return ((src_rgb - tgt_rgb) ** 2).sum(dim=-3).mean() * 4.0

    def calculate_normal_consist(self, normal_path):
        """Render both meshes from four cameras and compare the images.

        A side-by-side visualisation (source grid stacked with target grid)
        is always written to ``normal_path``.

        Args:
            normal_path: Destination file for the visualisation image.

        Returns:
            A scalar tensor with the error over the 4-view grid, or, when
            the renderer returns more than four images, a list with one
            scalar tensor per rendered view.
        """
        cam_ids = [0, 1, 2, 3]

        self.render.meshes = self.src_mesh
        src_normal_imgs = self.render.get_rgb_image(cam_ids=cam_ids,
                                                    bg='black')
        self.render.meshes = self.tgt_mesh
        tgt_normal_imgs = self.render.get_rgb_image(cam_ids=cam_ids,
                                                    bg='black')

        # Tile the views into one image per mesh and normalise them once;
        # the result is used for both the saved picture and the grid error.
        src_normal_arr = self._unit_normal_to_rgb(
            make_grid(torch.cat(src_normal_imgs, dim=0), nrow=4, padding=0))
        tgt_normal_arr = self._unit_normal_to_rgb(
            make_grid(torch.cat(tgt_normal_imgs, dim=0), nrow=4, padding=0))

        stacked = torch.cat([src_normal_arr, tgt_normal_arr], dim=1)
        normal_img = Image.fromarray(
            (stacked.permute(1, 2, 0).detach().cpu().numpy() *
             255.0).astype(np.uint8))
        normal_img.save(normal_path)

        if len(src_normal_imgs) > 4:
            # NOTE(review): the views look like (1, C, H, W) tensors given
            # the torch.cat(dim=0) + make_grid usage above; the helpers
            # reduce over the channel axis either way - confirm with Render.
            return [
                self._normal_error(self._unit_normal_to_rgb(src_view),
                                   self._unit_normal_to_rgb(tgt_view))
                for src_view, tgt_view in zip(src_normal_imgs,
                                              tgt_normal_imgs)
            ]

        return self._normal_error(src_normal_arr, tgt_normal_arr)

    def export_mesh(self, dir, name):
        """Save both meshes as ``<name>_src.obj`` / ``<name>_tgt.obj``.

        Args:
            dir: Output directory.
            name: File-name prefix.
        """
        IO().save_mesh(self.src_mesh, osp.join(dir, f"{name}_src.obj"))
        IO().save_mesh(self.tgt_mesh, osp.join(dir, f"{name}_tgt.obj"))

    def calculate_chamfer_p2s(self, num_samples=1000):
        """Compute the chamfer and point-to-surface distances.

        ``num_samples`` points are sampled on each mesh. Both reported
        values are scaled by 100.

        Args:
            num_samples: Number of surface samples drawn per mesh.

        Returns:
            ``(chamfer_dist, p2s_dist)`` where ``p2s_dist`` is the distance
            from target-mesh samples to the source mesh and ``chamfer_dist``
            is the mean of ``p2s_dist`` and the opposite direction.
        """
        tgt_points = Pointclouds(
            sample_points_from_meshes(self.tgt_mesh, num_samples))
        src_points = Pointclouds(
            sample_points_from_meshes(self.src_mesh, num_samples))

        p2s_dist = point_mesh_distance(self.src_mesh, tgt_points) * 100.0
        s2p_dist = point_mesh_distance(self.tgt_mesh, src_points) * 100.0
        chamfer_dist = (s2p_dist + p2s_dist) * 0.5

        return chamfer_dist, p2s_dist

    @staticmethod
    def _binarize(values, thres):
        """Set entries below ``thres`` to 0 and above to 1; ties are kept.

        Both masks are taken from the untouched input so the second fill
        cannot be affected by the first one.
        """
        below = values < thres
        above = values > thres
        return values.masked_fill(below, 0.0).masked_fill(above, 1.0)

    def calc_acc(self, output, target, thres=0.5, use_sdf=False):
        """Compute accuracy, IoU, precision and recall of a thresholded prediction.

        Args:
            output: Predicted values (presumably occupancy scores - confirm
                with the caller).
            target: Reference values of the same shape. Assumed to be 0/1
                already unless ``use_sdf`` is set.
            thres: Decision threshold; entries strictly greater count as
                positive.
            use_sdf: When true, ``target`` is binarised with ``thres`` the
                same way as ``output`` before the accuracy is computed.

        Returns:
            ``(acc, iou, precision, recall)`` as scalar tensors. The
            denominators are clamped to at least 1 to avoid division by
            zero; the intersection itself is not clamped, so disjoint
            prediction and target yield 0 for IoU, precision and recall.
        """
        with torch.no_grad():
            # Positive masks are taken from the raw values so that they stay
            # meaningful for thresholds outside [0, 1).
            pred_mask = output > thres
            gt_mask = target > thres

            output = self._binarize(output, thres)
            if use_sdf:
                target = self._binarize(target, thres)

            acc = output.eq(target).float().mean()

            # IoU, precision, recall.
            true_pos = (pred_mask & gt_mask).sum().float()
            union = (pred_mask | gt_mask).sum().float().clamp(min=1.0)
            vol_pred = pred_mask.sum().float().clamp(min=1.0)
            vol_gt = gt_mask.sum().float().clamp(min=1.0)

            return (acc, true_pos / union, true_pos / vol_pred,
                    true_pos / vol_gt)
|
|