| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # pyre-unsafe | |
| import torch | |
| import torch.nn.functional as Fu | |
| from pytorch3d.ops import wmean | |
| from pytorch3d.renderer.cameras import CamerasBase | |
| from pytorch3d.structures import Pointclouds | |
def cleanup_eval_depth(
    point_cloud: Pointclouds,
    camera: CamerasBase,
    depth: torch.Tensor,
    mask: torch.Tensor,
    sigma: float = 0.01,
    image=None,
):
    """
    Compute a mask of depth-map pixels that are consistent with a point cloud.

    Every point of `point_cloud` is projected with `camera`, and the depth map
    is sampled (nearest neighbour) at the projected location. A point "agrees"
    with the depth map when the absolute difference between the sampled depth
    and the point's view-space depth is at most `sigma` times the
    mask-weighted standard deviation of `depth`. A pixel is kept when at least
    one agreeing point lands on it. Points that project outside the image never
    contribute.

    Args:
        point_cloud: Batch of `ba` point clouds, one per depth map.
        camera: Cameras used to project the points (presumably one per depth
            map, or a single broadcastable camera -- TODO confirm).
        depth: Depth maps of shape `(ba, 1, H, W)`.
        mask: Per-pixel weights with `H * W` entries per batch element, used
            for the weighted mean / standard deviation of `depth`.
        sigma: Multiple of the depth standard deviation used as the agreement
            threshold.
        image: Unused; kept for interface compatibility.

    Returns:
        Float tensor of shape `(ba, 1, H, W)`: 1 for pixels hit by at least one
        agreeing point, 0 elsewhere.
    """
    ba, _, H, W = depth.shape

    pcl = point_cloud.points_padded()
    n_pts = point_cloud.num_points_per_cloud()
    # 1 for real points, 0 for the padding introduced by `points_padded()`.
    pcl_mask = (
        torch.arange(pcl.shape[1], dtype=torch.int64, device=pcl.device)[None]
        < n_pts[:, None]
    ).type_as(pcl)

    # Projected (x, y) coordinates and view-space depth of every point.
    pcl_proj = camera.transform_points(pcl, eps=1e-2)[..., :-1]
    pcl_depth = camera.get_world_to_view_transform().transform_points(pcl)[..., -1]

    # Sample, in a single grid_sample call:
    #   channel 0: the depth map,
    #   channel 1: the flat pixel index (exact only while H * W fits in the
    #              mantissa of depth's dtype, e.g. 2**24 for float32),
    #   channel 2: all ones -- with zero padding this reads 0 exactly for
    #              points projecting outside the image, which distinguishes
    #              them from genuine hits on pixel 0.
    pix_idx = (
        torch.arange(H * W, device=depth.device)
        .view(1, 1, H, W)
        .expand(ba, 1, H, W)
        .type_as(depth)
    )
    depth_idx_inside = torch.cat((depth, pix_idx, torch.ones_like(depth)), dim=1)

    # The negation presumably converts PyTorch3D's NDC axes (+X left, +Y up)
    # to grid_sample's convention (-1 = left / top).
    # NOTE(review): grid_sample expects both axes in [-1, 1]; for non-square
    # images PyTorch3D's NDC range on the longer side may differ -- confirm.
    sampled = Fu.grid_sample(
        depth_idx_inside,
        -pcl_proj[:, None],
        mode="nearest",
        align_corners=False,  # the default; explicit to avoid the warning
    )[:, :, 0]  # (ba, 3, n_points)
    depth_sampled = sampled[:, 0]
    idx_sampled = sampled[:, 1]
    inside = sampled[:, 2]

    df = (depth_sampled - pcl_depth).abs()

    # The threshold is a sigma-multiple of the mask-weighted standard deviation
    # of the depth. The variance is clamped at 1e-4, so the std is >= 1e-2.
    depth_flat = depth.reshape(ba, -1)
    mask_flat = mask.reshape(ba, -1)
    mu = wmean(depth_flat[..., None], mask_flat).view(ba, 1)
    var = wmean((depth_flat - mu).pow(2)[..., None], mask_flat)
    std = var.clamp(1e-4).sqrt().view(ba, -1)
    good_df_thr = std * sigma

    # Agreeing points: real (non-padding) points inside the image whose depth
    # matches the sampled depth within the threshold.
    good_depth = (df <= good_df_thr).float() * pcl_mask * inside

    # Count agreeing points per pixel; any positive count keeps the pixel.
    good_depth_raster = torch.zeros(
        ba, H * W, dtype=depth.dtype, device=depth.device
    )
    good_depth_raster.scatter_add_(
        1,
        torch.round(idx_sampled).long(),
        good_depth.to(good_depth_raster.dtype),
    )
    return (good_depth_raster.view(ba, 1, H, W) > 0).float()