| | |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from models.render_utils import sample_ptsFeatures_from_featureMaps, sample_ptsFeatures_from_featureVolume |
| |
|
def safe_l2_normalize(x, dim=None, eps=1e-6):
    """
    Scale ``x`` to unit L2 norm along ``dim``.

    Thin wrapper around ``torch.nn.functional.normalize`` with ``p=2``.
    ``eps`` is forwarded as the lower bound of the denominator, so all-zero
    vectors come back as zeros instead of NaN.

    :param x: tensor to normalize
    :param dim: dimension to normalize over; forwarded unchanged to ``F.normalize``
    :param eps: small constant guarding against division by zero
    :return: tensor with the same shape as ``x``
    """
    return F.normalize(x, p=2, dim=dim, eps=eps)
| |
|
class Projector():
    """
    Obtain features from geometryVolume and rendering_feature_maps for generalized rendering

    The two public entry points are :meth:`compute` (view-dependent angular
    encoding against a query camera) and :meth:`compute_view_independent`
    (angular encoding against SDF surface normals).  Both share the private
    helpers below for view selection and feature sampling.
    """

    # ------------------------------------------------------------------
    # angular encodings
    # ------------------------------------------------------------------
    @staticmethod
    def _unit_dirs_to_cameras(xyz_flat, c2ws):
        """Unit vectors from each point to each camera centre.

        :param xyz_flat: [n_pts, 3]
        :param c2ws: [n, 4, 4]; the translation column ``[:, :3, 3]`` is used as camera centre
        :return: [n, n_pts, 3]
        """
        dirs = c2ws[:, :3, 3].unsqueeze(1) - xyz_flat.unsqueeze(0)
        # the epsilon is added to the norm (not clamped) to match the established numerics
        return dirs / (torch.norm(dirs, dim=-1, keepdim=True) + 1e-6)

    @staticmethod
    def _encode_direction_pair(ref_dirs, support_dirs):
        """Encode a pair of directions as (unit difference vector, dot product).

        :param ref_dirs: reference directions, broadcastable against ``support_dirs``
        :param support_dirs: [n, n_pts, 3]
        :return: [n, n_pts, 4]
        """
        diff = ref_dirs - support_dirs
        diff_norm = torch.norm(diff, dim=-1, keepdim=True)
        dot = torch.sum(ref_dirs * support_dirs, dim=-1, keepdim=True)
        # clamp so that coincident directions yield a zero vector rather than NaN
        diff_direction = diff / torch.clamp(diff_norm, min=1e-6)
        return torch.cat([diff_direction, dot], dim=-1)

    @torch.no_grad()
    def compute_angle(self, xyz, query_c2w, supporting_c2ws):
        """
        Angular difference between the query-camera direction and every
        supporting-camera direction, evaluated at each point.

        Runs without autograd; the result never carries gradients.

        :param xyz: [N_rays, n_samples,3 ]
        :param query_c2w: [1,4,4]
        :param supporting_c2ws: [n,4,4]
        :return: [n, N_rays, n_samples, 4]; channels 0-2 are the normalized
                 direction difference, channel 3 is the dot product of the two directions
        """
        N_rays, n_samples, _ = xyz.shape
        num_views = supporting_c2ws.shape[0]
        xyz_flat = xyz.reshape(-1, 3)

        ray2tar_pose = self._unit_dirs_to_cameras(xyz_flat, query_c2w)
        ray2support_pose = self._unit_dirs_to_cameras(xyz_flat, supporting_c2ws)
        ray_diff = self._encode_direction_pair(ray2tar_pose, ray2support_pose)
        return ray_diff.reshape((num_views, N_rays, n_samples, 4))

    @torch.no_grad()
    def compute_angle_view_independent(self, xyz, surface_normals, supporting_c2ws):
        """
        Same encoding as :meth:`compute_angle`, but the reference direction is
        the surface normal at each point instead of the direction to a query camera.

        Runs without autograd; the result never carries gradients.

        :param xyz: [N_rays, n_samples,3 ]
        :param surface_normals: [N_rays, n_samples,3 ] or already flattened
                                [N_rays * n_samples, 3]
        :param supporting_c2ws: [n,4,4]
        :return: [n, N_rays, n_samples, 4]
        """
        N_rays, n_samples, _ = xyz.shape
        num_views = supporting_c2ws.shape[0]
        xyz_flat = xyz.reshape(-1, 3)

        # Flatten so the normals line up one-to-one with xyz_flat; without this a
        # [N_rays, n_samples, 3] input would not broadcast against [n, n_pts, 3].
        normals_flat = surface_normals.reshape(-1, 3)
        ray2support_pose = self._unit_dirs_to_cameras(xyz_flat, supporting_c2ws)
        ray_diff = self._encode_direction_pair(normals_flat, ray2support_pose)
        return ray_diff.reshape((num_views, N_rays, n_samples, 4))

    @torch.no_grad()
    def compute_z_diff(self, xyz, w2cs, intrinsics, pred_depth_values):
        """
        compute the depth difference of query pts projected on the image and the predicted depth values of the image

        :param xyz: [N_rays, n_samples,3 ]
        :param w2cs: [N_views, 4, 4]
        :param intrinsics: [N_views, 3, 3]
        :param pred_depth_values: [N_views, N_rays, n_samples,1 ]
        :return: [N_views, N_rays, n_samples, 1], projected depth minus predicted depth
        """
        N_views = w2cs.shape[0]
        N_rays, n_samples, _ = xyz.shape

        proj_matrix = torch.matmul(intrinsics, w2cs[:, :3, :])
        proj_rot = proj_matrix[:, :3, :3]
        proj_trans = proj_matrix[:, :3, 3:]

        # [3, n_pts] shared by all views; expand avoids materialising N_views copies
        batch_xyz = xyz.permute(2, 0, 1).contiguous().view(1, 3, N_rays * n_samples).expand(N_views, -1, -1)
        proj_xyz = proj_rot.bmm(batch_xyz) + proj_trans

        # third row is the depth along the camera axis; keep it strictly positive
        Z = proj_xyz[:, 2].clamp(min=1e-3)
        proj_z = Z.view(N_views, N_rays, n_samples, 1)

        return proj_z - pred_depth_values

    # ------------------------------------------------------------------
    # shared helpers for compute / compute_view_independent
    # ------------------------------------------------------------------
    @staticmethod
    def _select_supporting_views(c2ws, w2cs, rendering_feature_maps, color_maps, intrinsics,
                                 query_img_idx, query_c2w, pred_depth_maps, pred_depth_masks, device):
        """Split the input views into the query view and the supporting views.

        If ``query_c2w`` is None and ``query_img_idx`` is non-negative, the view at
        ``query_img_idx`` becomes the query and is removed from the supporting set.
        Otherwise every view is a supporting view and ``query_c2w`` is passed through.

        :return: dict with keys ``query_c2w``, ``c2ws``, ``w2cs``, ``feature_maps``,
                 ``color_maps``, ``intrinsics``, ``depth_maps``, ``depth_masks``, ``n_views``
        """
        N_views = rendering_feature_maps.shape[0]

        if query_c2w is None and query_img_idx > -1:
            query_idx = int(query_img_idx)
            keep = torch.tensor([v for v in range(N_views) if v != query_idx],
                                dtype=torch.long, device=device)
            query_sel = torch.tensor([query_idx], dtype=torch.long, device=device)

            depth_maps, depth_masks = None, None
            if pred_depth_maps is not None:
                depth_maps = torch.index_select(pred_depth_maps, 0, keep)
                depth_masks = torch.index_select(pred_depth_masks, 0, keep)

            return {
                "query_c2w": torch.index_select(c2ws, 0, query_sel),
                "c2ws": torch.index_select(c2ws, 0, keep),
                "w2cs": torch.index_select(w2cs, 0, keep),
                "feature_maps": torch.index_select(rendering_feature_maps, 0, keep),
                "color_maps": torch.index_select(color_maps, 0, keep),
                "intrinsics": torch.index_select(intrinsics, 0, keep),
                "depth_maps": depth_maps,
                "depth_masks": depth_masks,
                "n_views": N_views - 1,
            }

        return {
            "query_c2w": query_c2w,
            "c2ws": c2ws,
            "w2cs": w2cs,
            "feature_maps": rendering_feature_maps,
            "color_maps": color_maps,
            "intrinsics": intrinsics,
            # depth values must come from the depth maps, not from the masks
            "depth_maps": pred_depth_maps,
            "depth_masks": pred_depth_masks,
            "n_views": N_views,
        }

    @staticmethod
    def _sample_geometry(pts, geometryVolume, geometryVolumeMask, vol_dims, partial_vol_origin, vol_size):
        """Sample the geometry volume and its mask at ``pts``.

        :return: ``(features, mask)``, or ``(None, None)`` when no volume is given.
                 ``mask`` is the sampler's own mask AND-ed with "sampled
                 volume-mask value > 0".
        """
        if geometryVolume is None:
            return None, None

        pts_geometry_feature, inside_mask = sample_ptsFeatures_from_featureVolume(
            pts, geometryVolume, vol_dims, partial_vol_origin, vol_size)

        # the sampler is fed a 4-D volume; add a leading dim to a bare 3-D mask
        if len(geometryVolumeMask.shape) == 3:
            geometryVolumeMask = geometryVolumeMask[None, :, :, :]

        sampled_mask, _ = sample_ptsFeatures_from_featureVolume(
            pts, geometryVolumeMask.to(geometryVolume.dtype), vol_dims,
            partial_vol_origin, vol_size)

        return pts_geometry_feature, inside_mask & (sampled_mask[..., 0] > 0)

    @staticmethod
    def _sample_image_features(pts, views, img_wh):
        """Sample colors and rendering features of the supporting views at ``pts``.

        The sampler output has its dim 1 moved to the last position
        (presumably the channel dim -- confirm against ``models.render_utils``).

        :return: ``(rgb_feats, rendering_mask)`` where ``rgb_feats`` concatenates
                 colors followed by rendering features along the last dim
        """
        feats, rendering_mask = sample_ptsFeatures_from_featureMaps(
            pts, views["feature_maps"], views["w2cs"], views["intrinsics"], img_wh,
            return_mask=True)
        feats = feats.permute(0, 2, 3, 1).contiguous()

        colors = sample_ptsFeatures_from_featureMaps(
            pts, views["color_maps"], views["w2cs"], views["intrinsics"], img_wh)
        colors = colors.permute(0, 2, 3, 1).contiguous()

        return torch.cat([colors, feats], dim=-1), rendering_mask

    @staticmethod
    def _merge_masks(pts_geometry_masks, pts_rendering_mask, n_supporting_views):
        """AND the per-point geometry mask (repeated per view) with the per-view image mask."""
        if pts_geometry_masks is None:
            return pts_rendering_mask
        per_view_geometry = pts_geometry_masks[None, :, :].repeat(n_supporting_views, 1, 1)
        return per_view_geometry & pts_rendering_mask

    def _sample_depth(self, pts, views, img_wh):
        """Sample predicted depth / depth masks at ``pts`` and compute the depth residual.

        :return: ``(z_diff, pts_pred_depth_masks)``; both None when no depth maps are available
        """
        if views["depth_maps"] is None:
            return None, None

        depth_values = sample_ptsFeatures_from_featureMaps(
            pts, views["depth_maps"], views["w2cs"], views["intrinsics"], img_wh)
        depth_values = depth_values.permute(0, 2, 3, 1).contiguous()

        depth_masks = sample_ptsFeatures_from_featureMaps(
            pts, views["depth_masks"].float(), views["w2cs"], views["intrinsics"], img_wh)
        depth_masks = depth_masks.permute(0, 2, 3, 1).contiguous()[:, :, :, 0]

        z_diff = self.compute_z_diff(pts, views["w2cs"], views["intrinsics"], depth_values)
        return z_diff, depth_masks

    # ------------------------------------------------------------------
    # public entry points
    # ------------------------------------------------------------------
    def compute(self,
                pts,
                # * 3d geometry feature volumes
                geometryVolume=None,
                geometryVolumeMask=None,
                vol_dims=None,
                partial_vol_origin=None,
                vol_size=None,
                # * 2d rendering feature maps
                rendering_feature_maps=None,
                color_maps=None,
                w2cs=None,
                intrinsics=None,
                img_wh=None,
                query_img_idx=0,
                query_c2w=None,
                pred_depth_maps=None,
                pred_depth_masks=None
                ):
        """
        extract features of pts for rendering

        :param pts: [N_rays, n_samples, 3]; a 2-D input gets a leading dim of 1
        :param geometryVolume: feature volume sampled at ``pts``; optional
        :param geometryVolumeMask: mask volume; required when ``geometryVolume`` is given
        :param vol_dims: forwarded to the volume sampler
        :param partial_vol_origin: forwarded to the volume sampler
        :param vol_size: forwarded to the volume sampler
        :param rendering_feature_maps: per-view feature maps; dim 0 is the view dim
        :param color_maps: per-view color maps
        :param w2cs: per-view world-to-camera matrices (inverted internally)
        :param intrinsics: per-view intrinsics
        :param img_wh: forwarded to the image sampler
        :param query_img_idx: by default, we render the first view of w2cs; only
               used when ``query_c2w`` is None, in which case it must be non-negative
        :param query_c2w: explicit query pose [1,4,4]; when given, all views are supporting views
        :param pred_depth_maps: optional per-view predicted depth maps
        :param pred_depth_masks: per-view depth masks; needed when ``pred_depth_maps`` is given
        :return: (pts_geometry_feature, rgb_feats, ray_diff, final_mask, z_diff, pts_pred_depth_masks)
        """
        device = pts.device
        c2ws = torch.inverse(w2cs)

        if len(pts.shape) == 2:
            pts = pts[None, :, :]

        views = self._select_supporting_views(
            c2ws, w2cs, rendering_feature_maps, color_maps, intrinsics,
            query_img_idx, query_c2w, pred_depth_maps, pred_depth_masks, device)

        pts_geometry_feature, pts_geometry_masks = self._sample_geometry(
            pts, geometryVolume, geometryVolumeMask, vol_dims, partial_vol_origin, vol_size)

        rgb_feats, pts_rendering_mask = self._sample_image_features(pts, views, img_wh)

        ray_diff = self.compute_angle(pts, views["query_c2w"], views["c2ws"])

        final_mask = self._merge_masks(pts_geometry_masks, pts_rendering_mask, views["n_views"])

        z_diff, pts_pred_depth_masks = self._sample_depth(pts, views, img_wh)

        return pts_geometry_feature, rgb_feats, ray_diff, final_mask, z_diff, pts_pred_depth_masks

    def compute_view_independent(
            self,
            pts,
            # * 3d geometry feature volumes
            geometryVolume=None,
            geometryVolumeMask=None,
            sdf_network=None,
            lod=0,
            vol_dims=None,
            partial_vol_origin=None,
            vol_size=None,
            # * 2d rendering feature maps
            rendering_feature_maps=None,
            color_maps=None,
            w2cs=None,
            target_candidate_w2cs=None,
            intrinsics=None,
            img_wh=None,
            query_img_idx=0,
            query_c2w=None,
            pred_depth_maps=None,
            pred_depth_masks=None
    ):
        """
        extract features of pts for rendering, with the angular encoding taken
        relative to the SDF surface normal instead of a query camera

        NOTE: although they default to None, ``geometryVolume`` and ``sdf_network``
        are used unconditionally to obtain the surface normals and must be supplied.

        :param pts: [N_rays, n_samples, 3]; a 2-D input gets a leading dim of 1
        :param geometryVolume: feature volume; sampled at ``pts`` and passed to ``sdf_network.gradient``
        :param geometryVolumeMask: mask volume matching ``geometryVolume``
        :param sdf_network: object exposing ``gradient(pts, volume, lod=...)``
        :param lod: forwarded to ``sdf_network.gradient``
        :param vol_dims: forwarded to the volume sampler
        :param partial_vol_origin: forwarded to the volume sampler
        :param vol_size: forwarded to the volume sampler
        :param rendering_feature_maps: per-view feature maps; dim 0 is the view dim
        :param color_maps: per-view color maps
        :param w2cs: per-view world-to-camera matrices (inverted internally)
        :param target_candidate_w2cs: unused by this method
        :param intrinsics: per-view intrinsics
        :param img_wh: forwarded to the image sampler
        :param query_img_idx: by default, we render the first view of w2cs; that
               view is excluded from the supporting set when ``query_c2w`` is None
        :param query_c2w: when given, all views are used as supporting views
        :param pred_depth_maps: optional per-view predicted depth maps
        :param pred_depth_masks: per-view depth masks; needed when ``pred_depth_maps`` is given
        :return: (pts_geometry_feature, rgb_feats, ren_ray_diff, final_mask, z_diff, pts_pred_depth_masks)
        """
        device = pts.device
        c2ws = torch.inverse(w2cs)

        if len(pts.shape) == 2:
            pts = pts[None, :, :]

        views = self._select_supporting_views(
            c2ws, w2cs, rendering_feature_maps, color_maps, intrinsics,
            query_img_idx, query_c2w, pred_depth_maps, pred_depth_masks, device)

        pts_geometry_feature, pts_geometry_masks = self._sample_geometry(
            pts, geometryVolume, geometryVolumeMask, vol_dims, partial_vol_origin, vol_size)

        rgb_feats, pts_rendering_mask = self._sample_image_features(pts, views, img_wh)

        # surface normals = normalized SDF gradients at the sample points
        gradients = sdf_network.gradient(
            pts.reshape(-1, 3),
            geometryVolume.unsqueeze(0),
            lod=lod
        ).squeeze()
        surface_normals = safe_l2_normalize(gradients, dim=-1)

        ren_ray_diff = self.compute_angle_view_independent(
            xyz=pts,
            surface_normals=surface_normals,
            supporting_c2ws=views["c2ws"]
        )

        final_mask = self._merge_masks(pts_geometry_masks, pts_rendering_mask, views["n_views"])

        z_diff, pts_pred_depth_masks = self._sample_depth(pts, views, img_wh)

        return pts_geometry_feature, rgb_feats, ren_ray_diff, final_mask, z_diff, pts_pred_depth_masks
| |
|