| | """ |
| | The codes are heavily borrowed from NeuS |
| | """ |
| |
|
| | import os |
| | import cv2 as cv |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import numpy as np |
| | import logging |
| | import mcubes |
| | from icecream import ic |
| | from models.render_utils import sample_pdf |
| |
|
| | from models.projector import Projector |
| | from tsparse.torchsparse_utils import sparse_to_dense_channel |
| |
|
| | from models.fast_renderer import FastRenderer |
| |
|
| | from models.patch_projector import PatchProjector |
| |
|
| |
|
| | class SparseNeuSRenderer(nn.Module): |
| | """ |
| | conditional neus render; |
| | optimize on normalized world space; |
| | warped by nn.Module to support DataParallel traning |
| | """ |
| |
|
| | def __init__(self, |
| | rendering_network_outside, |
| | sdf_network, |
| | variance_network, |
| | rendering_network, |
| | n_samples, |
| | n_importance, |
| | n_outside, |
| | perturb, |
| | alpha_type='div', |
| | conf=None |
| | ): |
| | super(SparseNeuSRenderer, self).__init__() |
| |
|
| | self.conf = conf |
| | self.base_exp_dir = conf['general.base_exp_dir'] |
| |
|
| | |
| | self.rendering_network_outside = rendering_network_outside |
| | self.sdf_network = sdf_network |
| | self.variance_network = variance_network |
| | self.rendering_network = rendering_network |
| |
|
| | self.n_samples = n_samples |
| | self.n_importance = n_importance |
| | self.n_outside = n_outside |
| | self.perturb = perturb |
| | self.alpha_type = alpha_type |
| |
|
| | self.rendering_projector = Projector() |
| |
|
| | self.h_patch_size = self.conf.get_int('model.h_patch_size', default=3) |
| | self.patch_projector = PatchProjector(self.h_patch_size) |
| |
|
| | self.ray_tracer = FastRenderer() |
| |
|
| | |
| | try: |
| | self.if_fitted_rendering = self.sdf_network.if_fitted_rendering |
| | except: |
| | self.if_fitted_rendering = False |
| |
|
| | def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_variance, |
| | conditional_valid_mask_volume=None): |
| | device = rays_o.device |
| | batch_size, n_samples = z_vals.shape |
| | pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] |
| |
|
| | if conditional_valid_mask_volume is not None: |
| | pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), conditional_valid_mask_volume) |
| | pts_mask = pts_mask.reshape(batch_size, n_samples) |
| | pts_mask = pts_mask[:, :-1] * pts_mask[:, 1:] |
| | else: |
| | pts_mask = torch.ones([batch_size, n_samples]).to(pts.device) |
| |
|
| | sdf = sdf.reshape(batch_size, n_samples) |
| | prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:] |
| | prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:] |
| | mid_sdf = (prev_sdf + next_sdf) * 0.5 |
| | dot_val = None |
| | if self.alpha_type == 'uniform': |
| | dot_val = torch.ones([batch_size, n_samples - 1]) * -1.0 |
| | else: |
| | dot_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5) |
| | prev_dot_val = torch.cat([torch.zeros([batch_size, 1]).to(device), dot_val[:, :-1]], dim=-1) |
| | dot_val = torch.stack([prev_dot_val, dot_val], dim=-1) |
| | dot_val, _ = torch.min(dot_val, dim=-1, keepdim=False) |
| | dot_val = dot_val.clip(-10.0, 0.0) * pts_mask |
| | dist = (next_z_vals - prev_z_vals) |
| | prev_esti_sdf = mid_sdf - dot_val * dist * 0.5 |
| | next_esti_sdf = mid_sdf + dot_val * dist * 0.5 |
| | prev_cdf = torch.sigmoid(prev_esti_sdf * inv_variance) |
| | next_cdf = torch.sigmoid(next_esti_sdf * inv_variance) |
| | alpha_sdf = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5) |
| |
|
| | alpha = alpha_sdf |
| |
|
| | |
| | alpha = pts_mask * alpha |
| |
|
| | weights = alpha * torch.cumprod( |
| | torch.cat([torch.ones([batch_size, 1]).to(device), 1. - alpha + 1e-7], -1), -1)[:, :-1] |
| |
|
| | z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach() |
| | return z_samples |
| |
|
| | def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, lod, |
| | sdf_network, gru_fusion, |
| | |
| | conditional_volume=None, |
| | conditional_valid_mask_volume=None |
| | ): |
| | device = rays_o.device |
| | batch_size, n_samples = z_vals.shape |
| | _, n_importance = new_z_vals.shape |
| | pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None] |
| |
|
| | if conditional_valid_mask_volume is not None: |
| | pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), conditional_valid_mask_volume) |
| | pts_mask = pts_mask.reshape(batch_size, n_importance) |
| | pts_mask_bool = (pts_mask > 0).view(-1) |
| | else: |
| | pts_mask = torch.ones([batch_size, n_importance]).to(pts.device) |
| |
|
| | new_sdf = torch.ones([batch_size * n_importance, 1]).to(pts.dtype).to(device) * 100 |
| |
|
| | if torch.sum(pts_mask) > 1: |
| | new_outputs = sdf_network.sdf(pts.reshape(-1, 3)[pts_mask_bool], conditional_volume, lod=lod) |
| | new_sdf[pts_mask_bool] = new_outputs['sdf_pts_scale%d' % lod] |
| |
|
| | new_sdf = new_sdf.view(batch_size, n_importance) |
| |
|
| | z_vals = torch.cat([z_vals, new_z_vals], dim=-1) |
| | sdf = torch.cat([sdf, new_sdf], dim=-1) |
| |
|
| | z_vals, index = torch.sort(z_vals, dim=-1) |
| | xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1) |
| | index = index.reshape(-1) |
| | sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance) |
| |
|
| | return z_vals, sdf |
| |
|
| | @torch.no_grad() |
| | def get_pts_mask_for_conditional_volume(self, pts, mask_volume): |
| | """ |
| | |
| | :param pts: [N, 3] |
| | :param mask_volume: [1, 1, X, Y, Z] |
| | :return: |
| | """ |
| | num_pts = pts.shape[0] |
| | pts = pts.view(1, 1, 1, num_pts, 3) |
| |
|
| | pts = torch.flip(pts, dims=[-1]) |
| |
|
| | pts_mask = F.grid_sample(mask_volume, pts, mode='nearest') |
| | pts_mask = pts_mask.view(-1, num_pts).permute(1, 0).contiguous() |
| |
|
| | return pts_mask |
| |
|
| | def render_core(self, |
| | rays_o, |
| | rays_d, |
| | z_vals, |
| | sample_dist, |
| | lod, |
| | sdf_network, |
| | rendering_network, |
| | background_alpha=None, |
| | background_sampled_color=None, |
| | background_rgb=None, |
| | alpha_inter_ratio=0.0, |
| | |
| | conditional_volume=None, |
| | conditional_valid_mask_volume=None, |
| | |
| | feature_maps=None, |
| | color_maps=None, |
| | w2cs=None, |
| | intrinsics=None, |
| | img_wh=None, |
| | query_c2w=None, |
| | if_general_rendering=True, |
| | if_render_with_grad=True, |
| | |
| | img_index=None, |
| | rays_uv=None, |
| | |
| | bg_num=0 |
| | ): |
| | device = rays_o.device |
| | N_rays = rays_o.shape[0] |
| | _, n_samples = z_vals.shape |
| | dists = z_vals[..., 1:] - z_vals[..., :-1] |
| | dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape).to(device)], -1) |
| |
|
| | mid_z_vals = z_vals + dists * 0.5 |
| | mid_dists = mid_z_vals[..., 1:] - mid_z_vals[..., :-1] |
| |
|
| | pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] |
| | dirs = rays_d[:, None, :].expand(pts.shape) |
| |
|
| | pts = pts.reshape(-1, 3) |
| | dirs = dirs.reshape(-1, 3) |
| |
|
| | |
| | if conditional_valid_mask_volume is not None: |
| | pts_mask = self.get_pts_mask_for_conditional_volume(pts, conditional_valid_mask_volume) |
| | pts_mask = pts_mask.reshape(N_rays, n_samples).float().detach() |
| | pts_mask_bool = (pts_mask > 0).view(-1) |
| |
|
| | if torch.sum(pts_mask_bool.float()) < 1: |
| | pts_mask_bool[:100] = True |
| |
|
| | else: |
| | pts_mask = torch.ones([N_rays, n_samples]).to(pts.device) |
| | |
| | |
| | sdf_nn_output = sdf_network.sdf(pts[pts_mask_bool], conditional_volume, lod=lod) |
| |
|
| | sdf = torch.ones([N_rays * n_samples, 1]).to(pts.dtype).to(device) * 100 |
| | sdf[pts_mask_bool] = sdf_nn_output['sdf_pts_scale%d' % lod] |
| | feature_vector_valid = sdf_nn_output['sdf_features_pts_scale%d' % lod] |
| | feature_vector = torch.zeros([N_rays * n_samples, feature_vector_valid.shape[1]]).to(pts.dtype).to(device) |
| | feature_vector[pts_mask_bool] = feature_vector_valid |
| |
|
| | |
| | gradients = torch.zeros([N_rays * n_samples, 3]).to(pts.dtype).to(device) |
| | |
| | gradients[pts_mask_bool] = sdf_network.gradient( |
| | pts[pts_mask_bool], conditional_volume, lod=lod).squeeze() |
| |
|
| | sampled_color_mlp = None |
| | rendering_valid_mask_mlp = None |
| | sampled_color_patch = None |
| | rendering_patch_mask = None |
| |
|
| | if self.if_fitted_rendering: |
| | position_latent = sdf_nn_output['sampled_latent_scale%d' % lod] |
| | sampled_color_mlp = torch.zeros([N_rays * n_samples, 3]).to(pts.dtype).to(device) |
| | sampled_color_mlp_mask = torch.zeros([N_rays * n_samples, 1]).to(pts.dtype).to(device) |
| |
|
| | |
| | pts_pixel_color, pts_pixel_mask = self.patch_projector.pixel_warp( |
| | pts[pts_mask_bool][:, None, :], color_maps, intrinsics, |
| | w2cs, img_wh=None) |
| | pts_pixel_color = pts_pixel_color[:, 0, :, :] |
| | pts_pixel_mask = pts_pixel_mask[:, 0, :] |
| |
|
| | |
| | if_patch_blending = False if rays_uv is None else True |
| | pts_patch_color, pts_patch_mask = None, None |
| | if if_patch_blending: |
| | pts_patch_color, pts_patch_mask = self.patch_projector.patch_warp( |
| | pts.reshape([N_rays, n_samples, 3]), |
| | rays_uv, gradients.reshape([N_rays, n_samples, 3]), |
| | color_maps, |
| | intrinsics[0], intrinsics, |
| | query_c2w[0], torch.inverse(w2cs), img_wh=None |
| | ) |
| | N_src, Npx = pts_patch_mask.shape[2:] |
| | pts_patch_color = pts_patch_color.view(N_rays * n_samples, N_src, Npx, 3)[pts_mask_bool] |
| | pts_patch_mask = pts_patch_mask.view(N_rays * n_samples, N_src, Npx)[pts_mask_bool] |
| |
|
| | sampled_color_patch = torch.zeros([N_rays * n_samples, Npx, 3]).to(device) |
| | sampled_color_patch_mask = torch.zeros([N_rays * n_samples, 1]).to(device) |
| |
|
| | sampled_color_mlp_, sampled_color_mlp_mask_, \ |
| | sampled_color_patch_, sampled_color_patch_mask_ = sdf_network.color_blend( |
| | pts[pts_mask_bool], |
| | position_latent, |
| | gradients[pts_mask_bool], |
| | dirs[pts_mask_bool], |
| | feature_vector[pts_mask_bool], |
| | img_index=img_index, |
| | pts_pixel_color=pts_pixel_color, |
| | pts_pixel_mask=pts_pixel_mask, |
| | pts_patch_color=pts_patch_color, |
| | pts_patch_mask=pts_patch_mask |
| |
|
| | ) |
| | sampled_color_mlp[pts_mask_bool] = sampled_color_mlp_ |
| | sampled_color_mlp_mask[pts_mask_bool] = sampled_color_mlp_mask_.float() |
| | sampled_color_mlp = sampled_color_mlp.view(N_rays, n_samples, 3) |
| | sampled_color_mlp_mask = sampled_color_mlp_mask.view(N_rays, n_samples) |
| | rendering_valid_mask_mlp = torch.mean(pts_mask * sampled_color_mlp_mask, dim=-1, keepdim=True) > 0.5 |
| |
|
| | |
| | if if_patch_blending: |
| | sampled_color_patch[pts_mask_bool] = sampled_color_patch_ |
| | sampled_color_patch_mask[pts_mask_bool] = sampled_color_patch_mask_.float() |
| | sampled_color_patch = sampled_color_patch.view(N_rays, n_samples, Npx, 3) |
| | sampled_color_patch_mask = sampled_color_patch_mask.view(N_rays, n_samples) |
| | rendering_patch_mask = torch.mean(pts_mask * sampled_color_patch_mask, dim=-1, |
| | keepdim=True) > 0.5 |
| | else: |
| | sampled_color_patch, rendering_patch_mask = None, None |
| |
|
| | if if_general_rendering: |
| | |
| | ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask, _, _ = self.rendering_projector.compute( |
| | pts.view(N_rays, n_samples, 3), |
| | |
| | geometryVolume=conditional_volume[0], |
| | geometryVolumeMask=conditional_valid_mask_volume[0], |
| | |
| | rendering_feature_maps=feature_maps, |
| | color_maps=color_maps, |
| | w2cs=w2cs, |
| | intrinsics=intrinsics, |
| | img_wh=img_wh, |
| | query_img_idx=0, |
| | query_c2w=query_c2w, |
| | ) |
| |
|
| | |
| | if if_render_with_grad: |
| | |
| | |
| | sampled_color, rendering_valid_mask = rendering_network( |
| | ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask) |
| | |
| | else: |
| | with torch.no_grad(): |
| | sampled_color, rendering_valid_mask = rendering_network( |
| | ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask) |
| | else: |
| | sampled_color, rendering_valid_mask = None, None |
| |
|
| | inv_variance = self.variance_network(feature_vector)[:, :1].clip(1e-6, 1e6) |
| |
|
| | true_dot_val = (dirs * gradients).sum(-1, keepdim=True) |
| |
|
| | iter_cos = -(F.relu(-true_dot_val * 0.5 + 0.5) * (1.0 - alpha_inter_ratio) + F.relu( |
| | -true_dot_val) * alpha_inter_ratio) |
| |
|
| | iter_cos = iter_cos * pts_mask.view(-1, 1) |
| |
|
| | true_estimate_sdf_half_next = sdf + iter_cos.clip(-10.0, 10.0) * dists.reshape(-1, 1) * 0.5 |
| | true_estimate_sdf_half_prev = sdf - iter_cos.clip(-10.0, 10.0) * dists.reshape(-1, 1) * 0.5 |
| |
|
| | prev_cdf = torch.sigmoid(true_estimate_sdf_half_prev * inv_variance) |
| | next_cdf = torch.sigmoid(true_estimate_sdf_half_next * inv_variance) |
| |
|
| | p = prev_cdf - next_cdf |
| | c = prev_cdf |
| |
|
| | if self.alpha_type == 'div': |
| | alpha_sdf = ((p + 1e-5) / (c + 1e-5)).reshape(N_rays, n_samples).clip(0.0, 1.0) |
| | elif self.alpha_type == 'uniform': |
| | uniform_estimate_sdf_half_next = sdf - dists.reshape(-1, 1) * 0.5 |
| | uniform_estimate_sdf_half_prev = sdf + dists.reshape(-1, 1) * 0.5 |
| | uniform_prev_cdf = torch.sigmoid(uniform_estimate_sdf_half_prev * inv_variance) |
| | uniform_next_cdf = torch.sigmoid(uniform_estimate_sdf_half_next * inv_variance) |
| | uniform_alpha = F.relu( |
| | (uniform_prev_cdf - uniform_next_cdf + 1e-5) / (uniform_prev_cdf + 1e-5)).reshape( |
| | N_rays, n_samples).clip(0.0, 1.0) |
| | alpha_sdf = uniform_alpha |
| | else: |
| | assert False |
| |
|
| | alpha = alpha_sdf |
| |
|
| | |
| | alpha = alpha * pts_mask |
| |
|
| | |
| | |
| | |
| | inside_sphere = pts_mask |
| | relax_inside_sphere = pts_mask |
| |
|
| | weights = alpha * torch.cumprod(torch.cat([torch.ones([N_rays, 1]).to(device), 1. - alpha + 1e-7], -1), -1)[:, |
| | :-1] |
| | weights_sum = weights.sum(dim=-1, keepdim=True) |
| | alpha_sum = alpha.sum(dim=-1, keepdim=True) |
| |
|
| | if bg_num > 0: |
| | weights_sum_fg = weights[:, :-bg_num].sum(dim=-1, keepdim=True) |
| | else: |
| | weights_sum_fg = weights_sum |
| |
|
| | if sampled_color is not None: |
| | color = (sampled_color * weights[:, :, None]).sum(dim=1) |
| | else: |
| | color = None |
| | |
| |
|
| | if background_rgb is not None and color is not None: |
| | color = color + background_rgb * (1.0 - weights_sum) |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | |
| | color_mlp = None |
| | |
| | if sampled_color_mlp is not None: |
| | color_mlp = (sampled_color_mlp * weights[:, :, None]).sum(dim=1) |
| |
|
| | if background_rgb is not None and color_mlp is not None: |
| | color_mlp = color_mlp + background_rgb * (1.0 - weights_sum) |
| |
|
| | |
| | blended_color_patch = None |
| | if sampled_color_patch is not None: |
| | blended_color_patch = (sampled_color_patch * weights[:, :, None, None]).sum(dim=1) |
| |
|
| | |
| |
|
| | gradient_error = (torch.linalg.norm(gradients.reshape(N_rays, n_samples, 3), ord=2, |
| | dim=-1) - 1.0) ** 2 |
| | |
| | gradient_error = (pts_mask * gradient_error).sum() / ( |
| | (pts_mask).sum() + 1e-5) |
| |
|
| | depth = (mid_z_vals * weights[:, :n_samples]).sum(dim=1, keepdim=True) |
| | |
| | |
| | |
| | |
| | return { |
| | 'color': color, |
| | 'color_mask': rendering_valid_mask, |
| | 'color_mlp': color_mlp, |
| | 'color_mlp_mask': rendering_valid_mask_mlp, |
| | 'sdf': sdf, |
| | 'depth': depth, |
| | 'dists': dists, |
| | 'gradients': gradients.reshape(N_rays, n_samples, 3), |
| | 'variance': 1.0 / inv_variance, |
| | 'mid_z_vals': mid_z_vals, |
| | 'weights': weights, |
| | 'weights_sum': weights_sum, |
| | 'alpha_sum': alpha_sum, |
| | 'alpha_mean': alpha.mean(), |
| | 'cdf': c.reshape(N_rays, n_samples), |
| | 'gradient_error': gradient_error, |
| | 'inside_sphere': inside_sphere, |
| | 'blended_color_patch': blended_color_patch, |
| | 'blended_color_patch_mask': rendering_patch_mask, |
| | 'weights_sum_fg': weights_sum_fg |
| | } |
| |
|
| | def render(self, rays_o, rays_d, near, far, sdf_network, rendering_network, |
| | perturb_overwrite=-1, |
| | background_rgb=None, |
| | alpha_inter_ratio=0.0, |
| | |
| | lod=None, |
| | conditional_volume=None, |
| | conditional_valid_mask_volume=None, |
| | |
| | feature_maps=None, |
| | color_maps=None, |
| | w2cs=None, |
| | intrinsics=None, |
| | img_wh=None, |
| | query_c2w=None, |
| | if_general_rendering=True, |
| | if_render_with_grad=True, |
| | |
| | img_index=None, |
| | rays_uv=None, |
| | |
| | pre_sample=False, |
| | |
| | bg_ratio=0.0 |
| | ): |
| | device = rays_o.device |
| | N_rays = len(rays_o) |
| | |
| | sample_dist = ((far - near) / self.n_samples).mean().item() |
| | z_vals = torch.linspace(0.0, 1.0, self.n_samples).to(device) |
| | z_vals = near + (far - near) * z_vals[None, :] |
| |
|
| | bg_num = int(self.n_samples * bg_ratio) |
| |
|
| | if z_vals.shape[0] == 1: |
| | z_vals = z_vals.repeat(N_rays, 1) |
| |
|
| | if bg_num > 0: |
| | z_vals_bg = z_vals[:, self.n_samples - bg_num:] |
| | z_vals = z_vals[:, :self.n_samples - bg_num] |
| |
|
| | n_samples = self.n_samples - bg_num |
| | perturb = self.perturb |
| |
|
| | |
| | if pre_sample: |
| | z_vals = self.sample_z_vals_from_maskVolume(rays_o, rays_d, near, far, |
| | conditional_valid_mask_volume) |
| |
|
| | if perturb_overwrite >= 0: |
| | perturb = perturb_overwrite |
| | if perturb > 0: |
| | |
| | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) |
| | upper = torch.cat([mids, z_vals[..., -1:]], -1) |
| | lower = torch.cat([z_vals[..., :1], mids], -1) |
| | |
| | t_rand = torch.rand(z_vals.shape).to(device) |
| | z_vals = lower + (upper - lower) * t_rand |
| |
|
| | background_alpha = None |
| | background_sampled_color = None |
| | z_val_before = z_vals.clone() |
| | |
| | if self.n_importance > 0: |
| | with torch.no_grad(): |
| | pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] |
| |
|
| | sdf_outputs = sdf_network.sdf( |
| | pts.reshape(-1, 3), conditional_volume, lod=lod) |
| | |
| | sdf = sdf_outputs['sdf_pts_scale%d' % lod].reshape(N_rays, self.n_samples - bg_num) |
| |
|
| | n_steps = 4 |
| | for i in range(n_steps): |
| | new_z_vals = self.up_sample(rays_o, rays_d, z_vals, sdf, self.n_importance // n_steps, |
| | 64 * 2 ** i, |
| | conditional_valid_mask_volume=conditional_valid_mask_volume, |
| | ) |
| |
|
| | |
| | |
| |
|
| | z_vals, sdf = self.cat_z_vals( |
| | rays_o, rays_d, z_vals, new_z_vals, sdf, lod, |
| | sdf_network, gru_fusion=False, |
| | conditional_volume=conditional_volume, |
| | conditional_valid_mask_volume=conditional_valid_mask_volume, |
| | ) |
| |
|
| | del sdf |
| |
|
| | n_samples = self.n_samples + self.n_importance |
| |
|
| | |
| | ret_outside = None |
| |
|
| | |
| | if bg_num > 0: |
| | z_vals = torch.cat([z_vals, z_vals_bg], dim=1) |
| | |
| | |
| | ret_fine = self.render_core(rays_o, |
| | rays_d, |
| | z_vals, |
| | sample_dist, |
| | lod, |
| | sdf_network, |
| | rendering_network, |
| | background_rgb=background_rgb, |
| | background_alpha=background_alpha, |
| | background_sampled_color=background_sampled_color, |
| | alpha_inter_ratio=alpha_inter_ratio, |
| | |
| | conditional_volume=conditional_volume, |
| | conditional_valid_mask_volume=conditional_valid_mask_volume, |
| | |
| | feature_maps=feature_maps, |
| | color_maps=color_maps, |
| | w2cs=w2cs, |
| | intrinsics=intrinsics, |
| | img_wh=img_wh, |
| | query_c2w=query_c2w, |
| | if_general_rendering=if_general_rendering, |
| | if_render_with_grad=if_render_with_grad, |
| | |
| | img_index=img_index, |
| | rays_uv=rays_uv |
| | ) |
| |
|
| | color_fine = ret_fine['color'] |
| |
|
| | if self.n_outside > 0: |
| | color_fine_mask = torch.logical_or(ret_fine['color_mask'], ret_outside['color_mask']) |
| | else: |
| | color_fine_mask = ret_fine['color_mask'] |
| |
|
| | weights = ret_fine['weights'] |
| | weights_sum = ret_fine['weights_sum'] |
| |
|
| | gradients = ret_fine['gradients'] |
| | mid_z_vals = ret_fine['mid_z_vals'] |
| |
|
| | |
| | depth = ret_fine['depth'] |
| | depth_varaince = ((mid_z_vals - depth) ** 2 * weights[:, :n_samples]).sum(dim=-1, keepdim=True) |
| | variance = ret_fine['variance'].reshape(N_rays, n_samples).mean(dim=-1, keepdim=True) |
| |
|
| | |
| | pts_random = torch.rand([1024, 3]).float().to(device) * 2 - 1 |
| | sdf_random = sdf_network.sdf(pts_random, conditional_volume, lod=lod)['sdf_pts_scale%d' % lod] |
| |
|
| | result = { |
| | 'depth': depth, |
| | 'color_fine': color_fine, |
| | 'color_fine_mask': color_fine_mask, |
| | 'color_outside': ret_outside['color'] if ret_outside is not None else None, |
| | 'color_outside_mask': ret_outside['color_mask'] if ret_outside is not None else None, |
| | 'color_mlp': ret_fine['color_mlp'], |
| | 'color_mlp_mask': ret_fine['color_mlp_mask'], |
| | 'variance': variance.mean(), |
| | 'cdf_fine': ret_fine['cdf'], |
| | 'depth_variance': depth_varaince, |
| | 'weights_sum': weights_sum, |
| | 'weights_max': torch.max(weights, dim=-1, keepdim=True)[0], |
| | 'alpha_sum': ret_fine['alpha_sum'].mean(), |
| | 'alpha_mean': ret_fine['alpha_mean'], |
| | 'gradients': gradients, |
| | 'weights': weights, |
| | 'gradient_error_fine': ret_fine['gradient_error'], |
| | 'inside_sphere': ret_fine['inside_sphere'], |
| | 'sdf': ret_fine['sdf'], |
| | 'sdf_random': sdf_random, |
| | 'blended_color_patch': ret_fine['blended_color_patch'], |
| | 'blended_color_patch_mask': ret_fine['blended_color_patch_mask'], |
| | 'weights_sum_fg': ret_fine['weights_sum_fg'] |
| | } |
| |
|
| | return result |
| |
|
| | @torch.no_grad() |
| | def sample_z_vals_from_sdfVolume(self, rays_o, rays_d, near, far, sdf_volume, mask_volume): |
| | |
| | device = rays_o.device |
| | N_rays = len(rays_o) |
| | n_samples = self.n_samples * 2 |
| |
|
| | z_vals = torch.linspace(0.0, 1.0, n_samples).to(device) |
| | z_vals = near + (far - near) * z_vals[None, :] |
| |
|
| | if z_vals.shape[0] == 1: |
| | z_vals = z_vals.repeat(N_rays, 1) |
| |
|
| | pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] |
| |
|
| | sdf = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), sdf_volume).reshape([N_rays, n_samples]) |
| |
|
| | new_z_vals = self.up_sample(rays_o, rays_d, z_vals, sdf, self.n_samples, |
| | 200, |
| | conditional_valid_mask_volume=mask_volume, |
| | ) |
| | return new_z_vals |
| |
|
| | @torch.no_grad() |
| | def sample_z_vals_from_maskVolume(self, rays_o, rays_d, near, far, mask_volume): |
| | device = rays_o.device |
| | N_rays = len(rays_o) |
| | n_samples = self.n_samples * 2 |
| |
|
| | z_vals = torch.linspace(0.0, 1.0, n_samples).to(device) |
| | z_vals = near + (far - near) * z_vals[None, :] |
| |
|
| | if z_vals.shape[0] == 1: |
| | z_vals = z_vals.repeat(N_rays, 1) |
| |
|
| | mid_z_vals = (z_vals[:, 1:] + z_vals[:, :-1]) * 0.5 |
| |
|
| | pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] |
| |
|
| | pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), mask_volume).reshape( |
| | [N_rays, n_samples - 1]) |
| |
|
| | |
| | weights = torch.where(pts_mask > 0, torch.ones_like(pts_mask).to(device), |
| | 0.1 * torch.ones_like(pts_mask).to(device)) |
| |
|
| | |
| | z_samples = sample_pdf(z_vals, weights, self.n_samples, det=True).detach() |
| | return z_samples |
| |
|
| | @torch.no_grad() |
| | def filter_pts_by_depthmaps(self, coords, pred_depth_maps, proj_matrices, |
| | partial_vol_origin, voxel_size, |
| | near, far, depth_interval, d_plane_nums): |
| | """ |
| | Use the pred_depthmaps to remove redundant pts (pruned by sdf, sdf always have two sides, the back side is useless) |
| | :param coords: [n, 3] int coords |
| | :param pred_depth_maps: [N_views, 1, h, w] |
| | :param proj_matrices: [N_views, 4, 4] |
| | :param partial_vol_origin: [3] |
| | :param voxel_size: 1 |
| | :param near: 1 |
| | :param far: 1 |
| | :param depth_interval: 1 |
| | :param d_plane_nums: 1 |
| | :return: |
| | """ |
| | device = pred_depth_maps.device |
| | n_views, _, sizeH, sizeW = pred_depth_maps.shape |
| |
|
| | if len(partial_vol_origin.shape) == 1: |
| | partial_vol_origin = partial_vol_origin[None, :] |
| | pts = coords * voxel_size + partial_vol_origin |
| |
|
| | rs_grid = pts.unsqueeze(0).expand(n_views, -1, -1) |
| | rs_grid = rs_grid.permute(0, 2, 1).contiguous() |
| | nV = rs_grid.shape[-1] |
| | rs_grid = torch.cat([rs_grid, torch.ones([n_views, 1, nV]).to(device)], dim=1) |
| |
|
| | |
| | im_p = proj_matrices @ rs_grid |
| | im_x, im_y, im_z = im_p[:, 0], im_p[:, 1], im_p[:, 2] |
| | im_x = im_x / im_z |
| | im_y = im_y / im_z |
| |
|
| | im_grid = torch.stack([2 * im_x / (sizeW - 1) - 1, 2 * im_y / (sizeH - 1) - 1], dim=-1) |
| |
|
| | im_grid = im_grid.view(n_views, 1, -1, 2) |
| | sampled_depths = torch.nn.functional.grid_sample(pred_depth_maps, im_grid, mode='bilinear', |
| | padding_mode='zeros', |
| | align_corners=True)[:, 0, 0, :] |
| | sampled_depths_valid = (sampled_depths > 0.5 * near).float() |
| | valid_d_min = (sampled_depths - d_plane_nums * depth_interval).clamp(near.item(), |
| | far.item()) * sampled_depths_valid |
| | valid_d_max = (sampled_depths + d_plane_nums * depth_interval).clamp(near.item(), |
| | far.item()) * sampled_depths_valid |
| |
|
| | mask = im_grid.abs() <= 1 |
| | mask = mask[:, 0] |
| | mask = (mask.sum(dim=-1) == 2) & (im_z > valid_d_min) & (im_z < valid_d_max) |
| |
|
| | mask = mask.view(n_views, -1) |
| | mask = mask.permute(1, 0).contiguous() |
| |
|
| | mask_final = torch.sum(mask.float(), dim=1, keepdim=False) > 0 |
| |
|
| | return mask_final |
| |
|
    @torch.no_grad()
    def get_valid_sparse_coords_by_sdf_depthfilter(self, sdf_volume, coords_volume, mask_volume, feature_volume,
                                                   pred_depth_maps, proj_matrices,
                                                   partial_vol_origin, voxel_size,
                                                   near, far, depth_interval, d_plane_nums,
                                                   threshold=0.02, maximum_pts=110000):
        """
        assume batch size == 1, from the first lod to get sparse voxels

        Select voxels close to the sdf zero level set, additionally filtered by
        the predicted depth maps (drops the back side of the surface), and cap
        the result at maximum_pts voxels.

        :param sdf_volume: [1, X, Y, Z]
        :param coords_volume: [3, X, Y, Z]
        :param mask_volume: [1, X, Y, Z]
        :param feature_volume: [C, X, Y, Z]
        :param pred_depth_maps: [N_views, 1, h, w]
        :param proj_matrices: [N_views, 4, 4]
        :param threshold: initial |sdf| cutoff; lowered while too many voxels survive
        :param maximum_pts: hard cap on the number of returned voxels
        :return: (valid_coords [M, 4] with a leading batch-index-0 column,
                  valid_feature [M, C])
        """
        device = coords_volume.device
        _, dX, dY, dZ = coords_volume.shape

        def prune(sdf_pts, coords_pts, mask_volume, threshold):
            # voxels near the zero level set of the sdf
            occupancy_mask = (torch.abs(sdf_pts) < threshold).squeeze(1)
            valid_coords = coords_pts[occupancy_mask]

            # remove voxels behind the surface using the predicted depth maps
            mask_filtered = self.filter_pts_by_depthmaps(valid_coords, pred_depth_maps, proj_matrices,
                                                         partial_vol_origin, voxel_size,
                                                         near, far, depth_interval, d_plane_nums)
            valid_coords = valid_coords[mask_filtered]

            # scatter surviving coords back into a dense occupancy volume
            occupancy_mask = sparse_to_dense_channel(valid_coords, 1, [dX, dY, dZ], 1, 0, device)

            # dilate the occupancy with a 7^3 average pool so neighbours survive
            occupancy_mask = occupancy_mask.float()
            occupancy_mask = occupancy_mask.view(1, 1, dX, dY, dZ)
            occupancy_mask = F.avg_pool3d(occupancy_mask, kernel_size=7, stride=1, padding=3)
            occupancy_mask = occupancy_mask.view(-1, 1) > 0

            final_mask = torch.logical_and(mask_volume, occupancy_mask)[:, 0]

            return final_mask, torch.sum(final_mask.float())

        C, dX, dY, dZ = feature_volume.shape
        # flatten all volumes to per-voxel rows (voxel-major order)
        sdf_volume = sdf_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1)
        coords_volume = coords_volume.permute(1, 2, 3, 0).contiguous().view(-1, 3)
        mask_volume = mask_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1)
        feature_volume = feature_volume.permute(1, 2, 3, 0).contiguous().view(-1, C)

        final_mask, valid_num = prune(sdf_volume, coords_volume, mask_volume, threshold)

        # tighten the sdf threshold until the voxel count fits (floor at 0.003)
        while (valid_num > maximum_pts) and (threshold > 0.003):
            threshold = threshold - 0.002
            final_mask, valid_num = prune(sdf_volume, coords_volume, mask_volume, threshold)

        valid_coords = coords_volume[final_mask]
        valid_feature = feature_volume[final_mask]

        # prepend the batch index column (always 0: batch size is assumed 1)
        valid_coords = torch.cat([torch.ones([valid_coords.shape[0], 1]).to(valid_coords.device) * 0,
                                  valid_coords], dim=1)

        # still too many voxels: randomly drop the excess to bound memory
        if valid_num > maximum_pts:
            valid_num = valid_num.long()
            occupancy = torch.ones([valid_num]).to(device) > 0
            choice = np.random.choice(valid_num.cpu().numpy(), valid_num.cpu().numpy() - maximum_pts,
                                      replace=False)
            ind = torch.nonzero(occupancy).to(device)
            occupancy[ind[choice]] = False
            valid_coords = valid_coords[occupancy]
            valid_feature = valid_feature[occupancy]

            print(threshold, "randomly sample to save memory")

        return valid_coords, valid_feature
| |
|
| | @torch.no_grad() |
| | def get_valid_sparse_coords_by_sdf(self, sdf_volume, coords_volume, mask_volume, feature_volume, threshold=0.02, |
| | maximum_pts=110000): |
| | """ |
| | assume batch size == 1, from the first lod to get sparse voxels |
| | :param sdf_volume: [num_pts, 1] |
| | :param coords_volume: [3, X, Y, Z] |
| | :param mask_volume: [1, X, Y, Z] |
| | :param feature_volume: [C, X, Y, Z] |
| | :param threshold: |
| | :return: |
| | """ |
| |
|
| | def prune(sdf_volume, mask_volume, threshold): |
| | occupancy_mask = torch.abs(sdf_volume) < threshold |
| |
|
| | |
| | occupancy_mask = occupancy_mask.float() |
| | occupancy_mask = occupancy_mask.view(1, 1, dX, dY, dZ) |
| | occupancy_mask = F.avg_pool3d(occupancy_mask, kernel_size=7, stride=1, padding=3) |
| | occupancy_mask = occupancy_mask.view(-1, 1) > 0 |
| |
|
| | final_mask = torch.logical_and(mask_volume, occupancy_mask)[:, 0] |
| |
|
| | return final_mask, torch.sum(final_mask.float()) |
| |
|
| | C, dX, dY, dZ = feature_volume.shape |
| | coords_volume = coords_volume.permute(1, 2, 3, 0).contiguous().view(-1, 3) |
| | mask_volume = mask_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1) |
| | feature_volume = feature_volume.permute(1, 2, 3, 0).contiguous().view(-1, C) |
| |
|
| | final_mask, valid_num = prune(sdf_volume, mask_volume, threshold) |
| |
|
| | while (valid_num > maximum_pts) and (threshold > 0.003): |
| | threshold = threshold - 0.002 |
| | final_mask, valid_num = prune(sdf_volume, mask_volume, threshold) |
| |
|
| | valid_coords = coords_volume[final_mask] |
| | valid_feature = feature_volume[final_mask] |
| |
|
| | valid_coords = torch.cat([torch.ones([valid_coords.shape[0], 1]).to(valid_coords.device) * 0, |
| | valid_coords], dim=1) |
| |
|
| | |
| | if valid_num > maximum_pts: |
| | device = sdf_volume.device |
| | valid_num = valid_num.long() |
| | occupancy = torch.ones([valid_num]).to(device) > 0 |
| | choice = np.random.choice(valid_num.cpu().numpy(), valid_num.cpu().numpy() - maximum_pts, |
| | replace=False) |
| | ind = torch.nonzero(occupancy).to(device) |
| | occupancy[ind[choice]] = False |
| | valid_coords = valid_coords[occupancy] |
| | valid_feature = valid_feature[occupancy] |
| |
|
| | print(threshold, "randomly sample to save memory") |
| |
|
| | return valid_coords, valid_feature |
| |
|
| | @torch.no_grad() |
| | def extract_fields(self, bound_min, bound_max, resolution, query_func, device, |
| | |
| | **kwargs |
| | ): |
| | N = 64 |
| | X = torch.linspace(bound_min[0], bound_max[0], resolution).to(device).split(N) |
| | Y = torch.linspace(bound_min[1], bound_max[1], resolution).to(device).split(N) |
| | Z = torch.linspace(bound_min[2], bound_max[2], resolution).to(device).split(N) |
| |
|
| | u = np.zeros([resolution, resolution, resolution], dtype=np.float32) |
| | with torch.no_grad(): |
| | for xi, xs in enumerate(X): |
| | for yi, ys in enumerate(Y): |
| | for zi, zs in enumerate(Z): |
| | xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing="ij") |
| | pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) |
| |
|
| | |
| | output = query_func(pts, **kwargs) |
| | sdf = output['sdf_pts_scale%d' % kwargs['lod']].reshape(len(xs), len(ys), |
| | len(zs)).detach().cpu().numpy() |
| |
|
| | u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = -1 * sdf |
| | return u |
| |
|
| | @torch.no_grad() |
| | def extract_geometry(self, sdf_network, bound_min, bound_max, resolution, threshold, device, occupancy_mask=None, |
| | |
| | **kwargs |
| | ): |
| | |
| |
|
| | u = self.extract_fields(bound_min, bound_max, resolution, |
| | lambda pts, **kwargs: sdf_network.sdf(pts, **kwargs), |
| | |
| | device, |
| | |
| | **kwargs |
| | ) |
| | if occupancy_mask is not None: |
| | dX, dY, dZ = occupancy_mask.shape |
| | empty_mask = 1 - occupancy_mask |
| | empty_mask = empty_mask.view(1, 1, dX, dY, dZ) |
| | |
| | |
| | empty_mask = F.interpolate(empty_mask, [resolution, resolution, resolution], mode='nearest') |
| | empty_mask = empty_mask.view(resolution, resolution, resolution).cpu().numpy() > 0 |
| | u[empty_mask] = -100 |
| | del empty_mask |
| |
|
| | vertices, triangles = mcubes.marching_cubes(u, threshold) |
| | b_max_np = bound_max.detach().cpu().numpy() |
| | b_min_np = bound_min.detach().cpu().numpy() |
| |
|
| | vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :] |
| | return vertices, triangles, u |
| |
|
    @torch.no_grad()
    def extract_depth_maps(self, sdf_network, con_volume, intrinsics, c2ws, H, W, near, far):
        """
        extract depth maps from the density volume

        Builds a per-pixel ray bundle for each camera, sphere-traces the sdf
        with self.ray_tracer (FastRenderer), and returns per-view depth maps
        with their hit masks (misses are zeroed).

        :param sdf_network: conditional sdf network consumed by the ray tracer
        :param con_volume: [1, 1+C, dX, dY, dZ] can be con_volume or sdf_volume
        :param intrinsics: [B, 3(+), 3(+)] camera intrinsics (inverted below)
        :param c2ws: [B, 4, 4]
        :param H: image height in pixels
        :param W: image width in pixels
        :param near: [1] tensor, near bound broadcast to every ray
        :param far: [1] tensor, far bound broadcast to every ray
        :return: (depth_maps [B, 1, H, W], depth_masks [B, 1, H, W] bool)
        """
        device = con_volume.device
        batch_size = intrinsics.shape[0]

        with torch.no_grad():
            # pixel-center coordinates for every (y, x) location
            ys, xs = torch.meshgrid(torch.linspace(0, H - 1, H),
                                    torch.linspace(0, W - 1, W), indexing="ij")
            p = torch.stack([xs, ys, torch.ones_like(ys)], dim=-1)  # H, W, 3

            intrinsics_inv = torch.inverse(intrinsics)

            p = p.view(-1, 3).float().to(device)  # N_rays, 3
            # unproject pixels to camera-space directions, then normalize
            p = torch.matmul(intrinsics_inv[:, None, :3, :3], p[:, :, None]).squeeze()  # Batch, N_rays, 3
            rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True)  # Batch, N_rays, 3
            # rotate directions into world space; origins are the camera centers
            rays_v = torch.matmul(c2ws[:, None, :3, :3], rays_v[:, :, :, None]).squeeze()  # Batch, N_rays, 3
            rays_o = c2ws[:, None, :3, 3].expand(rays_v.shape)  # Batch, N_rays, 3
            rays_d = rays_v

            rays_o = rays_o.contiguous().view(-1, 3)
            rays_d = rays_d.contiguous().view(-1, 3)

            # sphere-trace all rays at once against the conditional volume
            depth_masks_sphere, depth_maps_sphere = self.ray_tracer.extract_depth_maps(
                rays_o, rays_d,
                near[None, :].repeat(rays_o.shape[0], 1),
                far[None, :].repeat(rays_o.shape[0], 1),
                sdf_network, con_volume
            )

        depth_maps = depth_maps_sphere.view(batch_size, 1, H, W)
        depth_masks = depth_masks_sphere.view(batch_size, 1, H, W)

        # zero out depths where the tracer reported no surface hit
        depth_maps = torch.where(depth_masks, depth_maps,
                                 torch.zeros_like(depth_masks.float()).to(device))  # fill invalid pixels by 0

        return depth_maps, depth_masks
| |
|