| """ |
| The codes are heavily borrowed from NeuS |
| """ |
|
|
| import os |
| import cv2 as cv |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import logging |
| import mcubes |
| from icecream import ic |
| from models.render_utils import sample_pdf |
|
|
| from models.projector import Projector |
| from tsparse.torchsparse_utils import sparse_to_dense_channel |
|
|
| from models.fast_renderer import FastRenderer |
|
|
| from models.patch_projector import PatchProjector |
|
|
|
|
| class SparseNeuSRenderer(nn.Module): |
| """ |
| conditional neus render; |
| optimize on normalized world space; |
| warped by nn.Module to support DataParallel traning |
| """ |
|
|
| def __init__(self, |
| rendering_network_outside, |
| sdf_network, |
| variance_network, |
| rendering_network, |
| n_samples, |
| n_importance, |
| n_outside, |
| perturb, |
| alpha_type='div', |
| conf=None |
| ): |
| super(SparseNeuSRenderer, self).__init__() |
|
|
| self.conf = conf |
| self.base_exp_dir = conf['general.base_exp_dir'] |
|
|
| |
| self.rendering_network_outside = rendering_network_outside |
| self.sdf_network = sdf_network |
| self.variance_network = variance_network |
| self.rendering_network = rendering_network |
|
|
| self.n_samples = n_samples |
| self.n_importance = n_importance |
| self.n_outside = n_outside |
| self.perturb = perturb |
| self.alpha_type = alpha_type |
|
|
| self.rendering_projector = Projector() |
|
|
| self.h_patch_size = self.conf.get_int('model.h_patch_size', default=3) |
| self.patch_projector = PatchProjector(self.h_patch_size) |
|
|
| self.ray_tracer = FastRenderer() |
|
|
| |
| try: |
| self.if_fitted_rendering = self.sdf_network.if_fitted_rendering |
| except: |
| self.if_fitted_rendering = False |
|
|
| def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_variance, |
| conditional_valid_mask_volume=None): |
| device = rays_o.device |
| batch_size, n_samples = z_vals.shape |
| pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] |
|
|
| if conditional_valid_mask_volume is not None: |
| pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), conditional_valid_mask_volume) |
| pts_mask = pts_mask.reshape(batch_size, n_samples) |
| pts_mask = pts_mask[:, :-1] * pts_mask[:, 1:] |
| else: |
| pts_mask = torch.ones([batch_size, n_samples]).to(pts.device) |
|
|
| sdf = sdf.reshape(batch_size, n_samples) |
| prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:] |
| prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:] |
| mid_sdf = (prev_sdf + next_sdf) * 0.5 |
| dot_val = None |
| if self.alpha_type == 'uniform': |
| dot_val = torch.ones([batch_size, n_samples - 1]) * -1.0 |
| else: |
| dot_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5) |
| prev_dot_val = torch.cat([torch.zeros([batch_size, 1]).to(device), dot_val[:, :-1]], dim=-1) |
| dot_val = torch.stack([prev_dot_val, dot_val], dim=-1) |
| dot_val, _ = torch.min(dot_val, dim=-1, keepdim=False) |
| dot_val = dot_val.clip(-10.0, 0.0) * pts_mask |
| dist = (next_z_vals - prev_z_vals) |
| prev_esti_sdf = mid_sdf - dot_val * dist * 0.5 |
| next_esti_sdf = mid_sdf + dot_val * dist * 0.5 |
| prev_cdf = torch.sigmoid(prev_esti_sdf * inv_variance) |
| next_cdf = torch.sigmoid(next_esti_sdf * inv_variance) |
| alpha_sdf = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5) |
|
|
| alpha = alpha_sdf |
|
|
| |
| alpha = pts_mask * alpha |
|
|
| weights = alpha * torch.cumprod( |
| torch.cat([torch.ones([batch_size, 1]).to(device), 1. - alpha + 1e-7], -1), -1)[:, :-1] |
|
|
| z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach() |
| return z_samples |
|
|
| def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, lod, |
| sdf_network, gru_fusion, |
| |
| conditional_volume=None, |
| conditional_valid_mask_volume=None |
| ): |
| device = rays_o.device |
| batch_size, n_samples = z_vals.shape |
| _, n_importance = new_z_vals.shape |
| pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None] |
|
|
| if conditional_valid_mask_volume is not None: |
| pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), conditional_valid_mask_volume) |
| pts_mask = pts_mask.reshape(batch_size, n_importance) |
| pts_mask_bool = (pts_mask > 0).view(-1) |
| else: |
| pts_mask = torch.ones([batch_size, n_importance]).to(pts.device) |
|
|
| new_sdf = torch.ones([batch_size * n_importance, 1]).to(pts.dtype).to(device) * 100 |
|
|
| if torch.sum(pts_mask) > 1: |
| new_outputs = sdf_network.sdf(pts.reshape(-1, 3)[pts_mask_bool], conditional_volume, lod=lod) |
| new_sdf[pts_mask_bool] = new_outputs['sdf_pts_scale%d' % lod] |
|
|
| new_sdf = new_sdf.view(batch_size, n_importance) |
|
|
| z_vals = torch.cat([z_vals, new_z_vals], dim=-1) |
| sdf = torch.cat([sdf, new_sdf], dim=-1) |
|
|
| z_vals, index = torch.sort(z_vals, dim=-1) |
| xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1) |
| index = index.reshape(-1) |
| sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance) |
|
|
| return z_vals, sdf |
|
|
| @torch.no_grad() |
| def get_pts_mask_for_conditional_volume(self, pts, mask_volume): |
| """ |
| |
| :param pts: [N, 3] |
| :param mask_volume: [1, 1, X, Y, Z] |
| :return: |
| """ |
| num_pts = pts.shape[0] |
| pts = pts.view(1, 1, 1, num_pts, 3) |
|
|
| pts = torch.flip(pts, dims=[-1]) |
|
|
| pts_mask = F.grid_sample(mask_volume, pts, mode='nearest') |
| pts_mask = pts_mask.view(-1, num_pts).permute(1, 0).contiguous() |
|
|
| return pts_mask |
|
|
| def render_core(self, |
| rays_o, |
| rays_d, |
| z_vals, |
| sample_dist, |
| lod, |
| sdf_network, |
| rendering_network, |
| background_alpha=None, |
| background_sampled_color=None, |
| background_rgb=None, |
| alpha_inter_ratio=0.0, |
| |
| conditional_volume=None, |
| conditional_valid_mask_volume=None, |
| |
| feature_maps=None, |
| color_maps=None, |
| w2cs=None, |
| intrinsics=None, |
| img_wh=None, |
| query_c2w=None, |
| if_general_rendering=True, |
| if_render_with_grad=True, |
| |
| img_index=None, |
| rays_uv=None, |
| |
| bg_num=0 |
| ): |
| device = rays_o.device |
| N_rays = rays_o.shape[0] |
| _, n_samples = z_vals.shape |
| dists = z_vals[..., 1:] - z_vals[..., :-1] |
| dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape).to(device)], -1) |
|
|
| mid_z_vals = z_vals + dists * 0.5 |
| mid_dists = mid_z_vals[..., 1:] - mid_z_vals[..., :-1] |
|
|
| pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] |
| dirs = rays_d[:, None, :].expand(pts.shape) |
|
|
| pts = pts.reshape(-1, 3) |
| dirs = dirs.reshape(-1, 3) |
|
|
| |
| if conditional_valid_mask_volume is not None: |
| pts_mask = self.get_pts_mask_for_conditional_volume(pts, conditional_valid_mask_volume) |
| pts_mask = pts_mask.reshape(N_rays, n_samples).float().detach() |
| pts_mask_bool = (pts_mask > 0).view(-1) |
|
|
| if torch.sum(pts_mask_bool.float()) < 1: |
| pts_mask_bool[:100] = True |
|
|
| else: |
| pts_mask = torch.ones([N_rays, n_samples]).to(pts.device) |
| |
| |
| sdf_nn_output = sdf_network.sdf(pts[pts_mask_bool], conditional_volume, lod=lod) |
|
|
| sdf = torch.ones([N_rays * n_samples, 1]).to(pts.dtype).to(device) * 100 |
| sdf[pts_mask_bool] = sdf_nn_output['sdf_pts_scale%d' % lod] |
| feature_vector_valid = sdf_nn_output['sdf_features_pts_scale%d' % lod] |
| feature_vector = torch.zeros([N_rays * n_samples, feature_vector_valid.shape[1]]).to(pts.dtype).to(device) |
| feature_vector[pts_mask_bool] = feature_vector_valid |
|
|
| |
| gradients = torch.zeros([N_rays * n_samples, 3]).to(pts.dtype).to(device) |
| |
| gradients[pts_mask_bool] = sdf_network.gradient( |
| pts[pts_mask_bool], conditional_volume, lod=lod).squeeze() |
|
|
| sampled_color_mlp = None |
| rendering_valid_mask_mlp = None |
| sampled_color_patch = None |
| rendering_patch_mask = None |
|
|
| if self.if_fitted_rendering: |
| position_latent = sdf_nn_output['sampled_latent_scale%d' % lod] |
| sampled_color_mlp = torch.zeros([N_rays * n_samples, 3]).to(pts.dtype).to(device) |
| sampled_color_mlp_mask = torch.zeros([N_rays * n_samples, 1]).to(pts.dtype).to(device) |
|
|
| |
| pts_pixel_color, pts_pixel_mask = self.patch_projector.pixel_warp( |
| pts[pts_mask_bool][:, None, :], color_maps, intrinsics, |
| w2cs, img_wh=None) |
| pts_pixel_color = pts_pixel_color[:, 0, :, :] |
| pts_pixel_mask = pts_pixel_mask[:, 0, :] |
|
|
| |
| if_patch_blending = False if rays_uv is None else True |
| pts_patch_color, pts_patch_mask = None, None |
| if if_patch_blending: |
| pts_patch_color, pts_patch_mask = self.patch_projector.patch_warp( |
| pts.reshape([N_rays, n_samples, 3]), |
| rays_uv, gradients.reshape([N_rays, n_samples, 3]), |
| color_maps, |
| intrinsics[0], intrinsics, |
| query_c2w[0], torch.inverse(w2cs), img_wh=None |
| ) |
| N_src, Npx = pts_patch_mask.shape[2:] |
| pts_patch_color = pts_patch_color.view(N_rays * n_samples, N_src, Npx, 3)[pts_mask_bool] |
| pts_patch_mask = pts_patch_mask.view(N_rays * n_samples, N_src, Npx)[pts_mask_bool] |
|
|
| sampled_color_patch = torch.zeros([N_rays * n_samples, Npx, 3]).to(device) |
| sampled_color_patch_mask = torch.zeros([N_rays * n_samples, 1]).to(device) |
|
|
| sampled_color_mlp_, sampled_color_mlp_mask_, \ |
| sampled_color_patch_, sampled_color_patch_mask_ = sdf_network.color_blend( |
| pts[pts_mask_bool], |
| position_latent, |
| gradients[pts_mask_bool], |
| dirs[pts_mask_bool], |
| feature_vector[pts_mask_bool], |
| img_index=img_index, |
| pts_pixel_color=pts_pixel_color, |
| pts_pixel_mask=pts_pixel_mask, |
| pts_patch_color=pts_patch_color, |
| pts_patch_mask=pts_patch_mask |
|
|
| ) |
| sampled_color_mlp[pts_mask_bool] = sampled_color_mlp_ |
| sampled_color_mlp_mask[pts_mask_bool] = sampled_color_mlp_mask_.float() |
| sampled_color_mlp = sampled_color_mlp.view(N_rays, n_samples, 3) |
| sampled_color_mlp_mask = sampled_color_mlp_mask.view(N_rays, n_samples) |
| rendering_valid_mask_mlp = torch.mean(pts_mask * sampled_color_mlp_mask, dim=-1, keepdim=True) > 0.5 |
|
|
| |
| if if_patch_blending: |
| sampled_color_patch[pts_mask_bool] = sampled_color_patch_ |
| sampled_color_patch_mask[pts_mask_bool] = sampled_color_patch_mask_.float() |
| sampled_color_patch = sampled_color_patch.view(N_rays, n_samples, Npx, 3) |
| sampled_color_patch_mask = sampled_color_patch_mask.view(N_rays, n_samples) |
| rendering_patch_mask = torch.mean(pts_mask * sampled_color_patch_mask, dim=-1, |
| keepdim=True) > 0.5 |
| else: |
| sampled_color_patch, rendering_patch_mask = None, None |
|
|
| if if_general_rendering: |
| |
| ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask, _, _ = self.rendering_projector.compute( |
| pts.view(N_rays, n_samples, 3), |
| |
| geometryVolume=conditional_volume[0], |
| geometryVolumeMask=conditional_valid_mask_volume[0], |
| |
| rendering_feature_maps=feature_maps, |
| color_maps=color_maps, |
| w2cs=w2cs, |
| intrinsics=intrinsics, |
| img_wh=img_wh, |
| query_img_idx=0, |
| query_c2w=query_c2w, |
| ) |
|
|
| |
| if if_render_with_grad: |
| |
| |
| sampled_color, rendering_valid_mask = rendering_network( |
| ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask) |
| |
| else: |
| with torch.no_grad(): |
| sampled_color, rendering_valid_mask = rendering_network( |
| ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask) |
| else: |
| sampled_color, rendering_valid_mask = None, None |
|
|
| inv_variance = self.variance_network(feature_vector)[:, :1].clip(1e-6, 1e6) |
|
|
| true_dot_val = (dirs * gradients).sum(-1, keepdim=True) |
|
|
| iter_cos = -(F.relu(-true_dot_val * 0.5 + 0.5) * (1.0 - alpha_inter_ratio) + F.relu( |
| -true_dot_val) * alpha_inter_ratio) |
|
|
| iter_cos = iter_cos * pts_mask.view(-1, 1) |
|
|
| true_estimate_sdf_half_next = sdf + iter_cos.clip(-10.0, 10.0) * dists.reshape(-1, 1) * 0.5 |
| true_estimate_sdf_half_prev = sdf - iter_cos.clip(-10.0, 10.0) * dists.reshape(-1, 1) * 0.5 |
|
|
| prev_cdf = torch.sigmoid(true_estimate_sdf_half_prev * inv_variance) |
| next_cdf = torch.sigmoid(true_estimate_sdf_half_next * inv_variance) |
|
|
| p = prev_cdf - next_cdf |
| c = prev_cdf |
|
|
| if self.alpha_type == 'div': |
| alpha_sdf = ((p + 1e-5) / (c + 1e-5)).reshape(N_rays, n_samples).clip(0.0, 1.0) |
| elif self.alpha_type == 'uniform': |
| uniform_estimate_sdf_half_next = sdf - dists.reshape(-1, 1) * 0.5 |
| uniform_estimate_sdf_half_prev = sdf + dists.reshape(-1, 1) * 0.5 |
| uniform_prev_cdf = torch.sigmoid(uniform_estimate_sdf_half_prev * inv_variance) |
| uniform_next_cdf = torch.sigmoid(uniform_estimate_sdf_half_next * inv_variance) |
| uniform_alpha = F.relu( |
| (uniform_prev_cdf - uniform_next_cdf + 1e-5) / (uniform_prev_cdf + 1e-5)).reshape( |
| N_rays, n_samples).clip(0.0, 1.0) |
| alpha_sdf = uniform_alpha |
| else: |
| assert False |
|
|
| alpha = alpha_sdf |
|
|
| |
| alpha = alpha * pts_mask |
|
|
| |
| |
| |
| inside_sphere = pts_mask |
| relax_inside_sphere = pts_mask |
|
|
| weights = alpha * torch.cumprod(torch.cat([torch.ones([N_rays, 1]).to(device), 1. - alpha + 1e-7], -1), -1)[:, |
| :-1] |
| weights_sum = weights.sum(dim=-1, keepdim=True) |
| alpha_sum = alpha.sum(dim=-1, keepdim=True) |
|
|
| if bg_num > 0: |
| weights_sum_fg = weights[:, :-bg_num].sum(dim=-1, keepdim=True) |
| else: |
| weights_sum_fg = weights_sum |
|
|
| if sampled_color is not None: |
| color = (sampled_color * weights[:, :, None]).sum(dim=1) |
| else: |
| color = None |
| |
|
|
| if background_rgb is not None and color is not None: |
| color = color + background_rgb * (1.0 - weights_sum) |
| |
| |
| |
| |
|
|
|
|
| |
| color_mlp = None |
| |
| if sampled_color_mlp is not None: |
| color_mlp = (sampled_color_mlp * weights[:, :, None]).sum(dim=1) |
|
|
| if background_rgb is not None and color_mlp is not None: |
| color_mlp = color_mlp + background_rgb * (1.0 - weights_sum) |
|
|
| |
| blended_color_patch = None |
| if sampled_color_patch is not None: |
| blended_color_patch = (sampled_color_patch * weights[:, :, None, None]).sum(dim=1) |
|
|
| |
|
|
| gradient_error = (torch.linalg.norm(gradients.reshape(N_rays, n_samples, 3), ord=2, |
| dim=-1) - 1.0) ** 2 |
| |
| gradient_error = (pts_mask * gradient_error).sum() / ( |
| (pts_mask).sum() + 1e-5) |
|
|
| depth = (mid_z_vals * weights[:, :n_samples]).sum(dim=1, keepdim=True) |
| |
| |
| |
| |
| return { |
| 'color': color, |
| 'color_mask': rendering_valid_mask, |
| 'color_mlp': color_mlp, |
| 'color_mlp_mask': rendering_valid_mask_mlp, |
| 'sdf': sdf, |
| 'depth': depth, |
| 'dists': dists, |
| 'gradients': gradients.reshape(N_rays, n_samples, 3), |
| 'variance': 1.0 / inv_variance, |
| 'mid_z_vals': mid_z_vals, |
| 'weights': weights, |
| 'weights_sum': weights_sum, |
| 'alpha_sum': alpha_sum, |
| 'alpha_mean': alpha.mean(), |
| 'cdf': c.reshape(N_rays, n_samples), |
| 'gradient_error': gradient_error, |
| 'inside_sphere': inside_sphere, |
| 'blended_color_patch': blended_color_patch, |
| 'blended_color_patch_mask': rendering_patch_mask, |
| 'weights_sum_fg': weights_sum_fg |
| } |
|
|
| def render(self, rays_o, rays_d, near, far, sdf_network, rendering_network, |
| perturb_overwrite=-1, |
| background_rgb=None, |
| alpha_inter_ratio=0.0, |
| |
| lod=None, |
| conditional_volume=None, |
| conditional_valid_mask_volume=None, |
| |
| feature_maps=None, |
| color_maps=None, |
| w2cs=None, |
| intrinsics=None, |
| img_wh=None, |
| query_c2w=None, |
| if_general_rendering=True, |
| if_render_with_grad=True, |
| |
| img_index=None, |
| rays_uv=None, |
| |
| pre_sample=False, |
| |
| bg_ratio=0.0 |
| ): |
| device = rays_o.device |
| N_rays = len(rays_o) |
| |
| sample_dist = ((far - near) / self.n_samples).mean().item() |
| z_vals = torch.linspace(0.0, 1.0, self.n_samples).to(device) |
| z_vals = near + (far - near) * z_vals[None, :] |
|
|
| bg_num = int(self.n_samples * bg_ratio) |
|
|
| if z_vals.shape[0] == 1: |
| z_vals = z_vals.repeat(N_rays, 1) |
|
|
| if bg_num > 0: |
| z_vals_bg = z_vals[:, self.n_samples - bg_num:] |
| z_vals = z_vals[:, :self.n_samples - bg_num] |
|
|
| n_samples = self.n_samples - bg_num |
| perturb = self.perturb |
|
|
| |
| if pre_sample: |
| z_vals = self.sample_z_vals_from_maskVolume(rays_o, rays_d, near, far, |
| conditional_valid_mask_volume) |
|
|
| if perturb_overwrite >= 0: |
| perturb = perturb_overwrite |
| if perturb > 0: |
| |
| mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) |
| upper = torch.cat([mids, z_vals[..., -1:]], -1) |
| lower = torch.cat([z_vals[..., :1], mids], -1) |
| |
| t_rand = torch.rand(z_vals.shape).to(device) |
| z_vals = lower + (upper - lower) * t_rand |
|
|
| background_alpha = None |
| background_sampled_color = None |
| z_val_before = z_vals.clone() |
| |
| if self.n_importance > 0: |
| with torch.no_grad(): |
| pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] |
|
|
| sdf_outputs = sdf_network.sdf( |
| pts.reshape(-1, 3), conditional_volume, lod=lod) |
| |
| sdf = sdf_outputs['sdf_pts_scale%d' % lod].reshape(N_rays, self.n_samples - bg_num) |
|
|
| n_steps = 4 |
| for i in range(n_steps): |
| new_z_vals = self.up_sample(rays_o, rays_d, z_vals, sdf, self.n_importance // n_steps, |
| 64 * 2 ** i, |
| conditional_valid_mask_volume=conditional_valid_mask_volume, |
| ) |
|
|
| |
| |
|
|
| z_vals, sdf = self.cat_z_vals( |
| rays_o, rays_d, z_vals, new_z_vals, sdf, lod, |
| sdf_network, gru_fusion=False, |
| conditional_volume=conditional_volume, |
| conditional_valid_mask_volume=conditional_valid_mask_volume, |
| ) |
|
|
| del sdf |
|
|
| n_samples = self.n_samples + self.n_importance |
|
|
| |
| ret_outside = None |
|
|
| |
| if bg_num > 0: |
| z_vals = torch.cat([z_vals, z_vals_bg], dim=1) |
| |
| |
| ret_fine = self.render_core(rays_o, |
| rays_d, |
| z_vals, |
| sample_dist, |
| lod, |
| sdf_network, |
| rendering_network, |
| background_rgb=background_rgb, |
| background_alpha=background_alpha, |
| background_sampled_color=background_sampled_color, |
| alpha_inter_ratio=alpha_inter_ratio, |
| |
| conditional_volume=conditional_volume, |
| conditional_valid_mask_volume=conditional_valid_mask_volume, |
| |
| feature_maps=feature_maps, |
| color_maps=color_maps, |
| w2cs=w2cs, |
| intrinsics=intrinsics, |
| img_wh=img_wh, |
| query_c2w=query_c2w, |
| if_general_rendering=if_general_rendering, |
| if_render_with_grad=if_render_with_grad, |
| |
| img_index=img_index, |
| rays_uv=rays_uv |
| ) |
|
|
| color_fine = ret_fine['color'] |
|
|
| if self.n_outside > 0: |
| color_fine_mask = torch.logical_or(ret_fine['color_mask'], ret_outside['color_mask']) |
| else: |
| color_fine_mask = ret_fine['color_mask'] |
|
|
| weights = ret_fine['weights'] |
| weights_sum = ret_fine['weights_sum'] |
|
|
| gradients = ret_fine['gradients'] |
| mid_z_vals = ret_fine['mid_z_vals'] |
|
|
| |
| depth = ret_fine['depth'] |
| depth_varaince = ((mid_z_vals - depth) ** 2 * weights[:, :n_samples]).sum(dim=-1, keepdim=True) |
| variance = ret_fine['variance'].reshape(N_rays, n_samples).mean(dim=-1, keepdim=True) |
|
|
| |
| pts_random = torch.rand([1024, 3]).float().to(device) * 2 - 1 |
| sdf_random = sdf_network.sdf(pts_random, conditional_volume, lod=lod)['sdf_pts_scale%d' % lod] |
|
|
| result = { |
| 'depth': depth, |
| 'color_fine': color_fine, |
| 'color_fine_mask': color_fine_mask, |
| 'color_outside': ret_outside['color'] if ret_outside is not None else None, |
| 'color_outside_mask': ret_outside['color_mask'] if ret_outside is not None else None, |
| 'color_mlp': ret_fine['color_mlp'], |
| 'color_mlp_mask': ret_fine['color_mlp_mask'], |
| 'variance': variance.mean(), |
| 'cdf_fine': ret_fine['cdf'], |
| 'depth_variance': depth_varaince, |
| 'weights_sum': weights_sum, |
| 'weights_max': torch.max(weights, dim=-1, keepdim=True)[0], |
| 'alpha_sum': ret_fine['alpha_sum'].mean(), |
| 'alpha_mean': ret_fine['alpha_mean'], |
| 'gradients': gradients, |
| 'weights': weights, |
| 'gradient_error_fine': ret_fine['gradient_error'], |
| 'inside_sphere': ret_fine['inside_sphere'], |
| 'sdf': ret_fine['sdf'], |
| 'sdf_random': sdf_random, |
| 'blended_color_patch': ret_fine['blended_color_patch'], |
| 'blended_color_patch_mask': ret_fine['blended_color_patch_mask'], |
| 'weights_sum_fg': ret_fine['weights_sum_fg'] |
| } |
|
|
| return result |
|
|
| @torch.no_grad() |
| def sample_z_vals_from_sdfVolume(self, rays_o, rays_d, near, far, sdf_volume, mask_volume): |
| |
| device = rays_o.device |
| N_rays = len(rays_o) |
| n_samples = self.n_samples * 2 |
|
|
| z_vals = torch.linspace(0.0, 1.0, n_samples).to(device) |
| z_vals = near + (far - near) * z_vals[None, :] |
|
|
| if z_vals.shape[0] == 1: |
| z_vals = z_vals.repeat(N_rays, 1) |
|
|
| pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] |
|
|
| sdf = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), sdf_volume).reshape([N_rays, n_samples]) |
|
|
| new_z_vals = self.up_sample(rays_o, rays_d, z_vals, sdf, self.n_samples, |
| 200, |
| conditional_valid_mask_volume=mask_volume, |
| ) |
| return new_z_vals |
|
|
| @torch.no_grad() |
| def sample_z_vals_from_maskVolume(self, rays_o, rays_d, near, far, mask_volume): |
| device = rays_o.device |
| N_rays = len(rays_o) |
| n_samples = self.n_samples * 2 |
|
|
| z_vals = torch.linspace(0.0, 1.0, n_samples).to(device) |
| z_vals = near + (far - near) * z_vals[None, :] |
|
|
| if z_vals.shape[0] == 1: |
| z_vals = z_vals.repeat(N_rays, 1) |
|
|
| mid_z_vals = (z_vals[:, 1:] + z_vals[:, :-1]) * 0.5 |
|
|
| pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] |
|
|
| pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), mask_volume).reshape( |
| [N_rays, n_samples - 1]) |
|
|
| |
| weights = torch.where(pts_mask > 0, torch.ones_like(pts_mask).to(device), |
| 0.1 * torch.ones_like(pts_mask).to(device)) |
|
|
| |
| z_samples = sample_pdf(z_vals, weights, self.n_samples, det=True).detach() |
| return z_samples |
|
|
| @torch.no_grad() |
| def filter_pts_by_depthmaps(self, coords, pred_depth_maps, proj_matrices, |
| partial_vol_origin, voxel_size, |
| near, far, depth_interval, d_plane_nums): |
| """ |
| Use the pred_depthmaps to remove redundant pts (pruned by sdf, sdf always have two sides, the back side is useless) |
| :param coords: [n, 3] int coords |
| :param pred_depth_maps: [N_views, 1, h, w] |
| :param proj_matrices: [N_views, 4, 4] |
| :param partial_vol_origin: [3] |
| :param voxel_size: 1 |
| :param near: 1 |
| :param far: 1 |
| :param depth_interval: 1 |
| :param d_plane_nums: 1 |
| :return: |
| """ |
| device = pred_depth_maps.device |
| n_views, _, sizeH, sizeW = pred_depth_maps.shape |
|
|
| if len(partial_vol_origin.shape) == 1: |
| partial_vol_origin = partial_vol_origin[None, :] |
| pts = coords * voxel_size + partial_vol_origin |
|
|
| rs_grid = pts.unsqueeze(0).expand(n_views, -1, -1) |
| rs_grid = rs_grid.permute(0, 2, 1).contiguous() |
| nV = rs_grid.shape[-1] |
| rs_grid = torch.cat([rs_grid, torch.ones([n_views, 1, nV]).to(device)], dim=1) |
|
|
| |
| im_p = proj_matrices @ rs_grid |
| im_x, im_y, im_z = im_p[:, 0], im_p[:, 1], im_p[:, 2] |
| im_x = im_x / im_z |
| im_y = im_y / im_z |
|
|
| im_grid = torch.stack([2 * im_x / (sizeW - 1) - 1, 2 * im_y / (sizeH - 1) - 1], dim=-1) |
|
|
| im_grid = im_grid.view(n_views, 1, -1, 2) |
| sampled_depths = torch.nn.functional.grid_sample(pred_depth_maps, im_grid, mode='bilinear', |
| padding_mode='zeros', |
| align_corners=True)[:, 0, 0, :] |
| sampled_depths_valid = (sampled_depths > 0.5 * near).float() |
| valid_d_min = (sampled_depths - d_plane_nums * depth_interval).clamp(near.item(), |
| far.item()) * sampled_depths_valid |
| valid_d_max = (sampled_depths + d_plane_nums * depth_interval).clamp(near.item(), |
| far.item()) * sampled_depths_valid |
|
|
| mask = im_grid.abs() <= 1 |
| mask = mask[:, 0] |
| mask = (mask.sum(dim=-1) == 2) & (im_z > valid_d_min) & (im_z < valid_d_max) |
|
|
| mask = mask.view(n_views, -1) |
| mask = mask.permute(1, 0).contiguous() |
|
|
| mask_final = torch.sum(mask.float(), dim=1, keepdim=False) > 0 |
|
|
| return mask_final |
|
|
| @torch.no_grad() |
| def get_valid_sparse_coords_by_sdf_depthfilter(self, sdf_volume, coords_volume, mask_volume, feature_volume, |
| pred_depth_maps, proj_matrices, |
| partial_vol_origin, voxel_size, |
| near, far, depth_interval, d_plane_nums, |
| threshold=0.02, maximum_pts=110000): |
| """ |
| assume batch size == 1, from the first lod to get sparse voxels |
| :param sdf_volume: [1, X, Y, Z] |
| :param coords_volume: [3, X, Y, Z] |
| :param mask_volume: [1, X, Y, Z] |
| :param feature_volume: [C, X, Y, Z] |
| :param threshold: |
| :return: |
| """ |
| device = coords_volume.device |
| _, dX, dY, dZ = coords_volume.shape |
|
|
| def prune(sdf_pts, coords_pts, mask_volume, threshold): |
| occupancy_mask = (torch.abs(sdf_pts) < threshold).squeeze(1) |
| valid_coords = coords_pts[occupancy_mask] |
|
|
| |
| mask_filtered = self.filter_pts_by_depthmaps(valid_coords, pred_depth_maps, proj_matrices, |
| partial_vol_origin, voxel_size, |
| near, far, depth_interval, d_plane_nums) |
| valid_coords = valid_coords[mask_filtered] |
|
|
| |
| occupancy_mask = sparse_to_dense_channel(valid_coords, 1, [dX, dY, dZ], 1, 0, device) |
|
|
| |
| occupancy_mask = occupancy_mask.float() |
| occupancy_mask = occupancy_mask.view(1, 1, dX, dY, dZ) |
| occupancy_mask = F.avg_pool3d(occupancy_mask, kernel_size=7, stride=1, padding=3) |
| occupancy_mask = occupancy_mask.view(-1, 1) > 0 |
|
|
| final_mask = torch.logical_and(mask_volume, occupancy_mask)[:, 0] |
|
|
| return final_mask, torch.sum(final_mask.float()) |
|
|
| C, dX, dY, dZ = feature_volume.shape |
| sdf_volume = sdf_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1) |
| coords_volume = coords_volume.permute(1, 2, 3, 0).contiguous().view(-1, 3) |
| mask_volume = mask_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1) |
| feature_volume = feature_volume.permute(1, 2, 3, 0).contiguous().view(-1, C) |
|
|
| |
| |
|
|
| final_mask, valid_num = prune(sdf_volume, coords_volume, mask_volume, threshold) |
|
|
| while (valid_num > maximum_pts) and (threshold > 0.003): |
| threshold = threshold - 0.002 |
| final_mask, valid_num = prune(sdf_volume, coords_volume, mask_volume, threshold) |
|
|
| valid_coords = coords_volume[final_mask] |
| valid_feature = feature_volume[final_mask] |
|
|
| valid_coords = torch.cat([torch.ones([valid_coords.shape[0], 1]).to(valid_coords.device) * 0, |
| valid_coords], dim=1) |
|
|
| |
| if valid_num > maximum_pts: |
| valid_num = valid_num.long() |
| occupancy = torch.ones([valid_num]).to(device) > 0 |
| choice = np.random.choice(valid_num.cpu().numpy(), valid_num.cpu().numpy() - maximum_pts, |
| replace=False) |
| ind = torch.nonzero(occupancy).to(device) |
| occupancy[ind[choice]] = False |
| valid_coords = valid_coords[occupancy] |
| valid_feature = valid_feature[occupancy] |
|
|
| print(threshold, "randomly sample to save memory") |
|
|
| return valid_coords, valid_feature |
|
|
| @torch.no_grad() |
| def get_valid_sparse_coords_by_sdf(self, sdf_volume, coords_volume, mask_volume, feature_volume, threshold=0.02, |
| maximum_pts=110000): |
| """ |
| assume batch size == 1, from the first lod to get sparse voxels |
| :param sdf_volume: [num_pts, 1] |
| :param coords_volume: [3, X, Y, Z] |
| :param mask_volume: [1, X, Y, Z] |
| :param feature_volume: [C, X, Y, Z] |
| :param threshold: |
| :return: |
| """ |
|
|
| def prune(sdf_volume, mask_volume, threshold): |
| occupancy_mask = torch.abs(sdf_volume) < threshold |
|
|
| |
| occupancy_mask = occupancy_mask.float() |
| occupancy_mask = occupancy_mask.view(1, 1, dX, dY, dZ) |
| occupancy_mask = F.avg_pool3d(occupancy_mask, kernel_size=7, stride=1, padding=3) |
| occupancy_mask = occupancy_mask.view(-1, 1) > 0 |
|
|
| final_mask = torch.logical_and(mask_volume, occupancy_mask)[:, 0] |
|
|
| return final_mask, torch.sum(final_mask.float()) |
|
|
| C, dX, dY, dZ = feature_volume.shape |
| coords_volume = coords_volume.permute(1, 2, 3, 0).contiguous().view(-1, 3) |
| mask_volume = mask_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1) |
| feature_volume = feature_volume.permute(1, 2, 3, 0).contiguous().view(-1, C) |
|
|
| final_mask, valid_num = prune(sdf_volume, mask_volume, threshold) |
|
|
| while (valid_num > maximum_pts) and (threshold > 0.003): |
| threshold = threshold - 0.002 |
| final_mask, valid_num = prune(sdf_volume, mask_volume, threshold) |
|
|
| valid_coords = coords_volume[final_mask] |
| valid_feature = feature_volume[final_mask] |
|
|
| valid_coords = torch.cat([torch.ones([valid_coords.shape[0], 1]).to(valid_coords.device) * 0, |
| valid_coords], dim=1) |
|
|
| |
| if valid_num > maximum_pts: |
| device = sdf_volume.device |
| valid_num = valid_num.long() |
| occupancy = torch.ones([valid_num]).to(device) > 0 |
| choice = np.random.choice(valid_num.cpu().numpy(), valid_num.cpu().numpy() - maximum_pts, |
| replace=False) |
| ind = torch.nonzero(occupancy).to(device) |
| occupancy[ind[choice]] = False |
| valid_coords = valid_coords[occupancy] |
| valid_feature = valid_feature[occupancy] |
|
|
| print(threshold, "randomly sample to save memory") |
|
|
| return valid_coords, valid_feature |
|
|
| @torch.no_grad() |
| def extract_fields(self, bound_min, bound_max, resolution, query_func, device, |
| |
| **kwargs |
| ): |
| N = 64 |
| X = torch.linspace(bound_min[0], bound_max[0], resolution).to(device).split(N) |
| Y = torch.linspace(bound_min[1], bound_max[1], resolution).to(device).split(N) |
| Z = torch.linspace(bound_min[2], bound_max[2], resolution).to(device).split(N) |
|
|
| u = np.zeros([resolution, resolution, resolution], dtype=np.float32) |
| with torch.no_grad(): |
| for xi, xs in enumerate(X): |
| for yi, ys in enumerate(Y): |
| for zi, zs in enumerate(Z): |
| xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing="ij") |
| pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) |
|
|
| |
| output = query_func(pts, **kwargs) |
| sdf = output['sdf_pts_scale%d' % kwargs['lod']].reshape(len(xs), len(ys), |
| len(zs)).detach().cpu().numpy() |
|
|
| u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = -1 * sdf |
| return u |
|
|
| @torch.no_grad() |
| def extract_geometry(self, sdf_network, bound_min, bound_max, resolution, threshold, device, occupancy_mask=None, |
| |
| **kwargs |
| ): |
| |
|
|
| u = self.extract_fields(bound_min, bound_max, resolution, |
| lambda pts, **kwargs: sdf_network.sdf(pts, **kwargs), |
| |
| device, |
| |
| **kwargs |
| ) |
| if occupancy_mask is not None: |
| dX, dY, dZ = occupancy_mask.shape |
| empty_mask = 1 - occupancy_mask |
| empty_mask = empty_mask.view(1, 1, dX, dY, dZ) |
| |
| |
| empty_mask = F.interpolate(empty_mask, [resolution, resolution, resolution], mode='nearest') |
| empty_mask = empty_mask.view(resolution, resolution, resolution).cpu().numpy() > 0 |
| u[empty_mask] = -100 |
| del empty_mask |
|
|
| vertices, triangles = mcubes.marching_cubes(u, threshold) |
| b_max_np = bound_max.detach().cpu().numpy() |
| b_min_np = bound_min.detach().cpu().numpy() |
|
|
| vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :] |
| return vertices, triangles, u |
|
|
| @torch.no_grad() |
| def extract_depth_maps(self, sdf_network, con_volume, intrinsics, c2ws, H, W, near, far): |
| """ |
| extract depth maps from the density volume |
| :param con_volume: [1, 1+C, dX, dY, dZ] can by con_volume or sdf_volume |
| :param c2ws: [B, 4, 4] |
| :param H: |
| :param W: |
| :param near: |
| :param far: |
| :return: |
| """ |
| device = con_volume.device |
| batch_size = intrinsics.shape[0] |
|
|
| with torch.no_grad(): |
| ys, xs = torch.meshgrid(torch.linspace(0, H - 1, H), |
| torch.linspace(0, W - 1, W), indexing="ij") |
| p = torch.stack([xs, ys, torch.ones_like(ys)], dim=-1) |
|
|
| intrinsics_inv = torch.inverse(intrinsics) |
|
|
| p = p.view(-1, 3).float().to(device) |
| p = torch.matmul(intrinsics_inv[:, None, :3, :3], p[:, :, None]).squeeze() |
| rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) |
| rays_v = torch.matmul(c2ws[:, None, :3, :3], rays_v[:, :, :, None]).squeeze() |
| rays_o = c2ws[:, None, :3, 3].expand(rays_v.shape) |
| rays_d = rays_v |
|
|
| rays_o = rays_o.contiguous().view(-1, 3) |
| rays_d = rays_d.contiguous().view(-1, 3) |
|
|
| |
| depth_masks_sphere, depth_maps_sphere = self.ray_tracer.extract_depth_maps( |
| rays_o, rays_d, |
| near[None, :].repeat(rays_o.shape[0], 1), |
| far[None, :].repeat(rays_o.shape[0], 1), |
| sdf_network, con_volume |
| ) |
|
|
| depth_maps = depth_maps_sphere.view(batch_size, 1, H, W) |
| depth_masks = depth_masks_sphere.view(batch_size, 1, H, W) |
|
|
| depth_maps = torch.where(depth_masks, depth_maps, |
| torch.zeros_like(depth_masks.float()).to(device)) |
|
|
| return depth_maps, depth_masks |
|
|