| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from lib.dataset.mesh_util import projection |
| | from lib.common.render import Render |
| | import numpy as np |
| | import torch |
| | import os.path as osp |
| | from torchvision.utils import make_grid |
| | from pytorch3d.io import IO |
| | from pytorch3d.ops import sample_points_from_meshes |
| | from pytorch3d.loss.point_mesh_distance import _PointFaceDistance |
| | from pytorch3d.structures import Pointclouds |
| | from PIL import Image |
| |
|
| |
|
def point_mesh_distance(meshes, pcls):
    """Batch-averaged mean distance from each point cloud to its paired mesh.

    Every point of ``pcls[i]`` is measured against the closest triangle of
    ``meshes[i]``. Distances are averaged inside each cloud (each point is
    weighted by 1 / number of points in its cloud) and then averaged over the
    N pairs of the batch.

    Args:
        meshes: pytorch3d ``Meshes``-like batch of size N.
        pcls: pytorch3d ``Pointclouds`` batch of size N, paired with ``meshes``.

    Returns:
        Scalar tensor with the batch-averaged mean point-to-face distance.

    Raises:
        ValueError: if the two batches do not have the same length.
    """
    num_pairs = len(meshes)
    if len(pcls) != num_pairs:
        raise ValueError("meshes and pointclouds must be equal sized batches")

    # Flattened ("packed") view of all clouds, plus where each cloud starts.
    flat_points = pcls.points_packed()
    cloud_offsets = pcls.cloud_to_packed_first_idx()
    cloud_sizes = pcls.num_points_per_cloud()

    # Flattened mesh data, plus where each mesh's faces start.
    flat_verts = meshes.verts_packed()
    flat_faces = meshes.faces_packed()
    face_offsets = meshes.mesh_to_faces_packed_first_idx()

    # Squared distance from every point to the closest face of its own mesh.
    # ``flat_verts[flat_faces]`` gathers the three vertex coordinates of each
    # face; 5e-3 is passed as the min_triangle_area argument.
    squared = _PointFaceDistance.apply(
        flat_points,
        cloud_offsets,
        flat_verts[flat_faces],
        face_offsets,
        cloud_sizes.max().item(),
        5e-3,
    )

    # Weight each point by 1 / (size of its cloud) so every cloud contributes
    # its own mean distance, then average over the batch.
    per_point_size = cloud_sizes.gather(0, pcls.packed_to_cloud_idx()).float()
    weighted = torch.sqrt(squared) * (1.0 / per_point_size)
    return weighted.sum() / num_pairs
| |
|
| |
|
class Evaluator:
    """Compare a reconstructed (predicted) mesh against its ground truth.

    ``set_mesh`` must be called before the mesh-based metrics
    (``calculate_normal_consist``, ``calculate_chamfer_p2s``, ``export_mesh``),
    which read the ``src_mesh`` / ``tgt_mesh`` attributes it builds.
    ``calc_acc`` works directly on tensors and needs no mesh.
    """

    def __init__(self, device):
        """Create the renderer (``size=512``) used for normal maps on ``device``."""
        self.render = Render(size=512, device=device)
        self.device = device

    def set_mesh(self, result_dict):
        """Load predicted / ground-truth geometry and build the two meshes.

        Every key of ``result_dict`` becomes an attribute of ``self``. This
        method reads ``verts_pr``, ``faces_pr``, ``recon_size``, ``verts_gt``,
        ``faces_gt`` and ``calib`` from it.

        Predicted vertices are mapped affinely so that 0 -> -1 and
        ``recon_size`` -> 1. Ground-truth vertices are passed through
        ``projection`` with ``calib`` and their second coordinate is negated.
        Sets ``self.src_mesh`` (prediction) and ``self.tgt_mesh`` (GT).
        """
        for k, v in result_dict.items():
            setattr(self, k, v)

        # Out-of-place on purpose: the previous ``-=`` / ``/=`` modified the
        # caller's verts_pr, so reusing the same dict re-normalised it.
        half_size = self.recon_size / 2.0
        self.verts_pr = (self.verts_pr - half_size) / half_size
        # NOTE(review): in-place flip assumes ``projection`` returns a new
        # array rather than a view of the caller's verts_gt - confirm.
        self.verts_gt = projection(self.verts_gt, self.calib)
        self.verts_gt[:, 1] *= -1

        self.src_mesh = self.render.VF2Mesh(self.verts_pr, self.faces_pr)
        self.tgt_mesh = self.render.VF2Mesh(self.verts_gt, self.faces_gt)

    @staticmethod
    def _normal_error(src, tgt):
        """Return ``(error, src_vis, tgt_vis)`` for two rendered normal images.

        Channels are taken from dim -3, so this handles both a (C, H, W) grid
        and a (B, C, H, W) batch. Each pixel is scaled to unit length (all-zero
        pixels, e.g. a black background, stay zero), mapped from [-1, 1] to
        [0, 1], and the error is 4 * mean over pixels of the channel-summed
        squared difference. The inputs are not modified.
        """
        src_norm = torch.norm(src, dim=-3, keepdim=True)
        tgt_norm = torch.norm(tgt, dim=-3, keepdim=True)
        # Avoid dividing zero-length (background) pixels by zero.
        src_norm[src_norm == 0.0] = 1.0
        tgt_norm[tgt_norm == 0.0] = 1.0

        src_vis = (src / src_norm + 1.0) * 0.5
        tgt_vis = (tgt / tgt_norm + 1.0) * 0.5
        error = ((src_vis - tgt_vis) ** 2).sum(dim=-3).mean() * 4
        return error, src_vis, tgt_vis

    def calculate_normal_consist(self, normal_path):
        """Compare rendered normal maps of the predicted and GT meshes.

        Both meshes are rendered from cameras 0-3 on a black background. A
        visualisation with the prediction stacked above the GT is saved to
        ``normal_path``.

        Returns:
            A list with one error per view when the renderer returns more than
            four images; otherwise a single scalar error over the tiled grid.
        """
        self.render.meshes = self.src_mesh
        src_normal_imgs = self.render.get_rgb_image(cam_ids=[0, 1, 2, 3],
                                                    bg='black')
        self.render.meshes = self.tgt_mesh
        tgt_normal_imgs = self.render.get_rgb_image(cam_ids=[0, 1, 2, 3],
                                                    bg='black')

        # Tile the views side by side (four per row) into one image per mesh.
        src_grid = make_grid(torch.cat(src_normal_imgs, dim=0), nrow=4, padding=0)
        tgt_grid = make_grid(torch.cat(tgt_normal_imgs, dim=0), nrow=4, padding=0)
        error, src_vis, tgt_vis = self._normal_error(src_grid, tgt_grid)

        # Stack prediction over GT along the height axis and save as uint8 RGB.
        normal_img = Image.fromarray(
            (torch.cat([src_vis, tgt_vis], dim=1).permute(
                1, 2, 0).detach().cpu().numpy() * 255.0).astype(np.uint8))
        normal_img.save(normal_path)

        if len(src_normal_imgs) > 4:
            # Per-view errors; _normal_error reduces over the channel axis
            # (dim -3), not the batch axis, of each rendered image.
            return [
                self._normal_error(src_img, tgt_img)[0]
                for src_img, tgt_img in zip(src_normal_imgs, tgt_normal_imgs)
            ]
        return error

    def export_mesh(self, dir, name):
        """Save the predicted mesh to ``<dir>/<name>_src.obj`` and the GT mesh
        to ``<dir>/<name>_tgt.obj``."""
        # ``dir`` shadows the builtin but is kept for keyword-argument callers.
        IO().save_mesh(self.src_mesh, osp.join(dir, f"{name}_src.obj"))
        IO().save_mesh(self.tgt_mesh, osp.join(dir, f"{name}_tgt.obj"))

    def calculate_chamfer_p2s(self, num_samples=1000):
        """Return ``(chamfer, p2s)`` between the predicted and GT meshes.

        ``num_samples`` points are sampled on each mesh. ``p2s`` is the mean
        distance from GT-surface samples to the predicted mesh; ``chamfer`` is
        the average of ``p2s`` and the mean distance from predicted-surface
        samples to the GT mesh. Both values are multiplied by 100.
        """
        tgt_points = Pointclouds(
            sample_points_from_meshes(self.tgt_mesh, num_samples))
        src_points = Pointclouds(
            sample_points_from_meshes(self.src_mesh, num_samples))
        p2s_dist = point_mesh_distance(self.src_mesh, tgt_points) * 100.0
        chamfer_dist = (point_mesh_distance(self.tgt_mesh, src_points) * 100.0
                        + p2s_dist) * 0.5
        return chamfer_dist, p2s_dist

    def calc_acc(self, output, target, thres=0.5, use_sdf=False):
        """Return ``(acc, iou, precision, recall)`` at threshold ``thres``.

        Args:
            output: predicted values, binarised at ``thres``.
            target: ground-truth values. For the accuracy term they are
                binarised at ``thres`` only when ``use_sdf`` is True; otherwise
                they are compared to the binarised output as given.
            thres: decision threshold.
            use_sdf: whether ``target`` also needs binarising for accuracy.

        Returns:
            Scalar tensors. IoU, precision and recall are 0 when there is no
            overlap (including when both volumes are empty).
        """
        with torch.no_grad():
            # Values exactly equal to thres are left unchanged by these fills.
            output = output.masked_fill(output < thres, 0.0)
            output = output.masked_fill(output > thres, 1.0)

            if use_sdf:
                target = target.masked_fill(target < thres, 0.0)
                target = target.masked_fill(target > thres, 1.0)

            acc = output.eq(target).float().mean()

            # IoU / precision / recall on boolean occupancy.
            output = output > thres
            target = target > thres

            true_pos = (output & target).sum().float()
            # Clamp only the denominators to avoid division by zero. The
            # true-positive count must not be clamped: doing so reported
            # non-zero IoU/recall and a precision of 1 with no overlap at all.
            union = (output | target).sum().float().clamp(min=1.0)
            vol_pred = output.sum().float().clamp(min=1.0)
            vol_gt = target.sum().float().clamp(min=1.0)

            return acc, true_pos / union, true_pos / vol_pred, true_pos / vol_gt
| |
|