| |
|
|
| import torch |
| import torch.nn.functional as F |
| from models.render_utils import sample_ptsFeatures_from_featureMaps, sample_ptsFeatures_from_featureVolume |
|
|
def safe_l2_normalize(x, dim=None, eps=1e-6):
    """Return ``x`` scaled to unit L2 norm along ``dim``.

    Delegates to ``torch.nn.functional.normalize`` with ``p=2``, which divides
    by ``max(||x||_2, eps)`` so all-zero slices stay zero instead of NaN.

    :param x: input tensor
    :param dim: dimension to normalize over (passed straight to ``F.normalize``)
    :param eps: lower bound on the divisor
    :return: tensor of the same shape as ``x``
    """
    unit = F.normalize(x, p=2, dim=dim, eps=eps)
    return unit
|
|
class Projector:
    """
    Obtain features from geometryVolume and rendering_feature_maps for generalized rendering.

    For a set of 3D sample points this gathers, relative to a set of
    "supporting" views:
      * features sampled from a 3D geometry volume (optional),
      * colors and 2D features sampled from the supporting views' images,
      * a 4-channel angular feature per (supporting view, point),
      * optionally, the difference between each point's projected depth and the
        predicted depth of each supporting view at that pixel.
    """

    # ------------------------------------------------------------------ #
    # angular features
    # ------------------------------------------------------------------ #
    @staticmethod
    def _unit_dirs_to_cameras(c2ws, xyz_flat):
        """Unit vectors from each point to each camera centre: [n_cams, n_pts, 3]."""
        dirs = c2ws[:, :3, 3].unsqueeze(1) - xyz_flat.unsqueeze(0)
        return dirs / (torch.norm(dirs, dim=-1, keepdim=True) + 1e-6)

    @staticmethod
    def _angle_features(ref_dirs, support_dirs):
        """Concat unit direction of (ref - support) with the cosine of ref·support: [..., 4]."""
        diff = ref_dirs - support_dirs
        diff_norm = torch.norm(diff, dim=-1, keepdim=True)
        cos = torch.sum(ref_dirs * support_dirs, dim=-1, keepdim=True)
        # clamp avoids 0/0 when the two directions coincide (direction becomes 0)
        direction = diff / torch.clamp(diff_norm, min=1e-6)
        return torch.cat([direction, cos], dim=-1)

    def compute_angle(self, xyz, query_c2w, supporting_c2ws):
        """
        Angular relation between the query camera and each supporting camera, per point.

        :param xyz: [N_rays, n_samples, 3] sample points
        :param query_c2w: [1, 4, 4] camera-to-world matrix of the query view
        :param supporting_c2ws: [n, 4, 4] camera-to-world matrices of the supporting views
        :return: detached [n, N_rays, n_samples, 4] tensor; channels 0-2 are the unit
            direction of (point->query camera) - (point->supporting camera), channel 3
            is the dot product of those two unit vectors.
        """
        N_rays, n_samples, _ = xyz.shape
        num_views = supporting_c2ws.shape[0]
        xyz_flat = xyz.reshape(-1, 3)

        ray2tar_pose = self._unit_dirs_to_cameras(query_c2w, xyz_flat)  # [1, P, 3]
        ray2support_pose = self._unit_dirs_to_cameras(supporting_c2ws, xyz_flat)  # [n, P, 3]
        ray_diff = self._angle_features(ray2tar_pose, ray2support_pose)
        return ray_diff.reshape(num_views, N_rays, n_samples, 4).detach()

    def compute_angle_view_independent(self, xyz, surface_normals, supporting_c2ws):
        """
        Like ``compute_angle`` but uses a per-point surface normal instead of the
        direction to the query camera, so the result does not depend on the query view.

        :param xyz: [N_rays, n_samples, 3] sample points
        :param surface_normals: one normal per point, either [N_rays, n_samples, 3] or
            flattened [N_rays * n_samples, 3]; any other shape is used as given and must
            broadcast against [n, N_rays * n_samples, 3]
        :param supporting_c2ws: [n, 4, 4] camera-to-world matrices of the supporting views
        :return: detached [n, N_rays, n_samples, 4] tensor (same channel layout as
            ``compute_angle``)
        """
        N_rays, n_samples, _ = xyz.shape
        num_views = supporting_c2ws.shape[0]
        xyz_flat = xyz.reshape(-1, 3)

        normals = surface_normals
        if normals.numel() == xyz_flat.numel():
            # flatten per-ray normals so they line up with the flattened points;
            # [N_rays, n_samples, 3] would otherwise not broadcast when N_rays > 1
            normals = normals.reshape(-1, 3)

        ray2support_pose = self._unit_dirs_to_cameras(supporting_c2ws, xyz_flat)
        ray_diff = self._angle_features(normals, ray2support_pose)
        return ray_diff.reshape(num_views, N_rays, n_samples, 4).detach()

    # ------------------------------------------------------------------ #
    # depth consistency
    # ------------------------------------------------------------------ #
    @torch.no_grad()
    def compute_z_diff(self, xyz, w2cs, intrinsics, pred_depth_values):
        """
        Depth of the query points projected into each view minus the predicted depth there.

        :param xyz: [N_rays, n_samples, 3]
        :param w2cs: [N_views, 4, 4]
        :param intrinsics: [N_views, 3, 3]
        :param pred_depth_values: [N_views, N_rays, n_samples, 1] predicted depth sampled
            at each point's projection
        :return: [N_views, N_rays, n_samples, 1]; the projected depth is clamped to at
            least 1e-3 before the subtraction
        """
        N_views = w2cs.shape[0]
        N_rays, n_samples, _ = xyz.shape
        proj_matrix = torch.matmul(intrinsics, w2cs[:, :3, :])  # [N_views, 3, 4]

        proj_rot = proj_matrix[:, :3, :3]
        proj_trans = proj_matrix[:, :3, 3:]

        batch_xyz = xyz.permute(2, 0, 1).contiguous().view(1, 3, N_rays * n_samples).repeat(N_views, 1, 1)
        proj_xyz = proj_rot.bmm(batch_xyz) + proj_trans

        # third homogeneous coordinate; clamped so points behind the camera get a tiny positive depth
        Z = proj_xyz[:, 2].clamp(min=1e-3)
        proj_z = Z.view(N_views, N_rays, n_samples, 1)
        return proj_z - pred_depth_values

    # ------------------------------------------------------------------ #
    # shared sampling helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _sample_geometry(pts, geometryVolume, geometryVolumeMask, vol_dims, partial_vol_origin, vol_size):
        """Sample the geometry volume; return (features, mask) or (None, None) without a volume."""
        if geometryVolume is None:
            return None, None

        pts_geometry_feature, pts_geometry_masks = sample_ptsFeatures_from_featureVolume(
            pts, geometryVolume, vol_dims, partial_vol_origin, vol_size)

        if geometryVolumeMask is None:
            # no occupancy mask supplied: keep only the mask returned by the volume sampler
            return pts_geometry_feature, pts_geometry_masks

        if geometryVolumeMask.dim() == 3:
            geometryVolumeMask = geometryVolumeMask[None, :, :, :]

        # sample the mask volume like a feature volume and keep points where it is > 0
        sampled_mask, _ = sample_ptsFeatures_from_featureVolume(
            pts, geometryVolumeMask.to(geometryVolume.dtype), vol_dims,
            partial_vol_origin, vol_size)
        pts_geometry_masks = pts_geometry_masks & (sampled_mask[..., 0] > 0)
        return pts_geometry_feature, pts_geometry_masks

    @staticmethod
    def _sample_colors_and_features(pts, feature_maps, color_maps, w2cs, intrinsics, img_wh):
        """Sample colors and 2D features; return (cat([colors, feats], -1), visibility mask)."""
        pts_rendering_feats, pts_rendering_mask = sample_ptsFeatures_from_featureMaps(
            pts, feature_maps, w2cs, intrinsics, img_wh, return_mask=True)
        # presumably [n_views, C, N_rays, n_samples] -> [n_views, N_rays, n_samples, C]; verify
        pts_rendering_feats = pts_rendering_feats.permute(0, 2, 3, 1).contiguous()

        pts_rendering_colors = sample_ptsFeatures_from_featureMaps(pts, color_maps, w2cs, intrinsics, img_wh)
        pts_rendering_colors = pts_rendering_colors.permute(0, 2, 3, 1).contiguous()

        rgb_feats = torch.cat([pts_rendering_colors, pts_rendering_feats], dim=-1)
        return rgb_feats, pts_rendering_mask

    def _sample_pred_depth(self, pts, depth_maps, depth_masks, w2cs, intrinsics, img_wh):
        """Sample predicted depth and its mask; return (z_diff, per-point depth mask)."""
        pts_pred_depth_values = sample_ptsFeatures_from_featureMaps(pts, depth_maps, w2cs, intrinsics, img_wh)
        pts_pred_depth_values = pts_pred_depth_values.permute(0, 2, 3, 1).contiguous()

        pts_pred_depth_masks = sample_ptsFeatures_from_featureMaps(pts, depth_masks.float(), w2cs, intrinsics, img_wh)
        pts_pred_depth_masks = pts_pred_depth_masks.permute(0, 2, 3, 1).contiguous()[..., 0]

        z_diff = self.compute_z_diff(pts, w2cs, intrinsics, pts_pred_depth_values)
        return z_diff, pts_pred_depth_masks

    def _gather_common(self, pts, geometryVolume, geometryVolumeMask, vol_dims, partial_vol_origin, vol_size,
                       rendering_feature_maps, color_maps, w2cs, intrinsics, img_wh,
                       query_img_idx, query_c2w, pred_depth_maps, pred_depth_masks):
        """
        Sampling pipeline shared by ``compute`` and ``compute_view_independent``.

        If ``query_c2w`` is None and ``query_img_idx`` > -1, view ``query_img_idx`` is the
        query and is removed from the supporting set; otherwise every view supports.

        :return: ``(pts, query_c2w, supporting_c2ws, pts_geometry_feature, rgb_feats,
            final_mask, z_diff, pts_pred_depth_masks)``; ``pts`` is always 3-D and
            ``query_c2w`` stays None if the caller supplied neither a pose nor a valid index.
        """
        device = pts.device
        c2ws = torch.inverse(w2cs)

        if pts.dim() == 2:
            pts = pts[None, :, :]

        n_views = rendering_feature_maps.shape[0]

        support_idxs = None
        if query_c2w is None and query_img_idx > -1:
            support_idxs = torch.LongTensor([i for i in range(n_views) if i != query_img_idx]).to(device)
            query_c2w = torch.index_select(c2ws, 0, torch.LongTensor([query_img_idx]).to(device))

        def pick(t):
            # restrict a per-view tensor to the supporting views (identity when all views support)
            if t is None or support_idxs is None:
                return t
            return torch.index_select(t, 0, support_idxs)

        supporting_c2ws = pick(c2ws)
        supporting_w2cs = pick(w2cs)
        supporting_intrinsics = pick(intrinsics)

        pts_geometry_feature, pts_geometry_masks = self._sample_geometry(
            pts, geometryVolume, geometryVolumeMask, vol_dims, partial_vol_origin, vol_size)

        rgb_feats, pts_rendering_mask = self._sample_colors_and_features(
            pts, pick(rendering_feature_maps), pick(color_maps),
            supporting_w2cs, supporting_intrinsics, img_wh)

        if pts_geometry_masks is not None:
            n_support = supporting_w2cs.shape[0]
            final_mask = pts_geometry_masks[None, :, :].repeat(n_support, 1, 1) & pts_rendering_mask
        else:
            final_mask = pts_rendering_mask

        z_diff, pts_pred_depth_masks = None, None
        if pred_depth_maps is not None:
            # NOTE: pred_depth_masks is required whenever pred_depth_maps is given
            z_diff, pts_pred_depth_masks = self._sample_pred_depth(
                pts, pick(pred_depth_maps), pick(pred_depth_masks),
                supporting_w2cs, supporting_intrinsics, img_wh)

        return (pts, query_c2w, supporting_c2ws, pts_geometry_feature, rgb_feats,
                final_mask, z_diff, pts_pred_depth_masks)

    # ------------------------------------------------------------------ #
    # public entry points
    # ------------------------------------------------------------------ #
    def compute(self,
                pts,
                # geometry volume
                geometryVolume=None,
                geometryVolumeMask=None,
                vol_dims=None,
                partial_vol_origin=None,
                vol_size=None,
                # image features
                rendering_feature_maps=None,
                color_maps=None,
                w2cs=None,
                intrinsics=None,
                img_wh=None,
                query_img_idx=0,
                query_c2w=None,
                pred_depth_maps=None,
                pred_depth_masks=None
                ):
        """
        Extract features of pts for rendering, with angles relative to the query camera.

        :param pts: [N_rays, n_samples, 3] or [n_samples, 3] sample points
        :param geometryVolume: optional feature volume; when None no geometry features/mask
        :param geometryVolumeMask: optional occupancy volume combined into the geometry mask
        :param vol_dims, partial_vol_origin, vol_size: volume placement for the sampler
        :param rendering_feature_maps: per-view 2D feature maps (first dim = views)
        :param color_maps: per-view images
        :param w2cs: [N_views, 4, 4] world-to-camera matrices
        :param intrinsics: [N_views, 3, 3]
        :param img_wh: image width/height for projection
        :param query_img_idx: index of the query view (default 0); used only when
            ``query_c2w`` is None
        :param query_c2w: optional [1, 4, 4] pose of a novel query view; if given, all
            views are supporting views
        :param pred_depth_maps: optional per-view predicted depth maps
        :param pred_depth_masks: per-view depth validity masks (needed with pred_depth_maps)
        :return: (pts_geometry_feature, rgb_feats, ray_diff, final_mask, z_diff,
            pts_pred_depth_masks); z_diff and pts_pred_depth_masks are None without
            pred_depth_maps
        """
        (pts, query_c2w, supporting_c2ws, pts_geometry_feature, rgb_feats,
         final_mask, z_diff, pts_pred_depth_masks) = self._gather_common(
            pts, geometryVolume, geometryVolumeMask, vol_dims, partial_vol_origin, vol_size,
            rendering_feature_maps, color_maps, w2cs, intrinsics, img_wh,
            query_img_idx, query_c2w, pred_depth_maps, pred_depth_masks)

        ray_diff = self.compute_angle(pts, query_c2w, supporting_c2ws)
        return pts_geometry_feature, rgb_feats, ray_diff, final_mask, z_diff, pts_pred_depth_masks

    def compute_view_independent(
            self,
            pts,
            # geometry volume
            geometryVolume=None,
            geometryVolumeMask=None,
            sdf_network=None,
            lod=0,
            vol_dims=None,
            partial_vol_origin=None,
            vol_size=None,
            # image features
            rendering_feature_maps=None,
            color_maps=None,
            w2cs=None,
            target_candidate_w2cs=None,
            intrinsics=None,
            img_wh=None,
            query_img_idx=0,
            query_c2w=None,
            pred_depth_maps=None,
            pred_depth_masks=None
    ):
        """
        Like ``compute``, but the angular feature uses SDF surface normals instead of the
        query camera direction.

        Parameters are as in ``compute``, plus:
        :param sdf_network: object exposing ``gradient(pts, volume, lod=...)``; required
        :param lod: level of detail passed to ``sdf_network.gradient``
        :param target_candidate_w2cs: accepted for interface compatibility; unused
        :return: (pts_geometry_feature, rgb_feats, ren_ray_diff, final_mask, z_diff,
            pts_pred_depth_masks)
        """
        (pts, query_c2w, supporting_c2ws, pts_geometry_feature, rgb_feats,
         final_mask, z_diff, pts_pred_depth_masks) = self._gather_common(
            pts, geometryVolume, geometryVolumeMask, vol_dims, partial_vol_origin, vol_size,
            rendering_feature_maps, color_maps, w2cs, intrinsics, img_wh,
            query_img_idx, query_c2w, pred_depth_maps, pred_depth_masks)

        # normals = normalized SDF gradient at every (flattened) sample point
        gradients = sdf_network.gradient(
            pts.reshape(-1, 3),
            geometryVolume.unsqueeze(0),
            lod=lod
        ).squeeze()
        surface_normals = safe_l2_normalize(gradients, dim=-1)

        ren_ray_diff = self.compute_angle_view_independent(
            xyz=pts,
            surface_normals=surface_normals,
            supporting_c2ws=supporting_c2ws
        )
        return pts_geometry_feature, rgb_feats, ren_ray_diff, final_mask, z_diff, pts_pred_depth_masks
|
|