Spaces:
Paused
Paused
| import os | |
| import time | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import time | |
| from torch.utils.cpp_extension import load | |
parent_dir = os.path.dirname(os.path.abspath(__file__))


def _load_cuda_ext(name, rel_sources):
    """JIT-compile and load a CUDA extension whose sources sit next to this file."""
    return load(
        name=name,
        sources=[os.path.join(parent_dir, rel) for rel in rel_sources],
        verbose=True)


# Custom CUDA ops used by the grid classes below.
render_utils_cuda = _load_cuda_ext(
    'render_utils_cuda',
    ['cuda/render_utils.cpp', 'cuda/render_utils_kernel.cu'])
total_variation_cuda = _load_cuda_ext(
    'total_variation_cuda',
    ['cuda/total_variation.cpp', 'cuda/total_variation_kernel.cu'])
def create_grid(type, **kwargs):
    """Factory for voxel-grid representations.

    Args:
        type: grid class name, either 'DenseGrid' or 'TensoRFGrid'.
        **kwargs: forwarded verbatim to the grid constructor.

    Raises:
        NotImplementedError: if `type` names an unknown grid class.
    """
    if type == 'DenseGrid':
        return DenseGrid(**kwargs)
    elif type == 'TensoRFGrid':
        return TensoRFGrid(**kwargs)
    else:
        # Include the offending name so configuration mistakes are easy to spot.
        raise NotImplementedError(f'create_grid: unknown grid type {type!r}')
''' Dense 3D grid
'''
class DenseGrid(nn.Module):
    """Dense voxel grid holding `channels` features per voxel.

    The grid spans the axis-aligned world-space box [xyz_min, xyz_max] and is
    trilinearly interpolated at query time.
    """
    def __init__(self, channels, world_size, xyz_min, xyz_max, **kwargs):
        super(DenseGrid, self).__init__()
        self.channels = channels
        self.world_size = world_size
        self.register_buffer('xyz_min', torch.Tensor(xyz_min))
        self.register_buffer('xyz_max', torch.Tensor(xyz_max))
        # Learnable voxel features, shape [1, channels, X, Y, Z].
        self.grid = nn.Parameter(torch.zeros([1, channels, *world_size]))
        # (removed a leftover debug print of xyz_min/xyz_max/world_size)

    def forward(self, xyz):
        '''
        xyz: [..., 3] global coordinates to query.
        Returns trilinearly interpolated features of shape [..., channels];
        the channel axis is squeezed away when channels == 1.
        '''
        shape = xyz.shape[:-1]
        xyz = xyz.reshape(1, 1, 1, -1, 3)
        # grid_sample expects coordinates in [-1, 1] with the last axis in
        # (z, y, x) order, hence the normalization and the flip.
        ind_norm = ((xyz - self.xyz_min) / (self.xyz_max - self.xyz_min)).flip((-1,)) * 2 - 1
        out = F.grid_sample(self.grid, ind_norm, mode='bilinear', align_corners=True)
        out = out.reshape(self.channels, -1).T.reshape(*shape, self.channels)
        if self.channels == 1:
            out = out.squeeze(-1)
        return out

    def scale_volume_grid(self, new_world_size):
        """Resample the grid to `new_world_size` via trilinear interpolation."""
        if self.channels == 0:
            # Interpolation is meaningless with zero channels; just reallocate.
            self.grid = nn.Parameter(torch.zeros([1, self.channels, *new_world_size]))
        else:
            self.grid = nn.Parameter(
                F.interpolate(self.grid.data, size=tuple(new_world_size), mode='trilinear', align_corners=True))

    def total_variation_add_grad(self, wx, wy, wz, dense_mode):
        '''Add gradients by total variation loss in-place (custom CUDA op).'''
        total_variation_cuda.total_variation_add_grad(
            self.grid, self.grid.grad, wx, wy, wz, dense_mode)

    def get_dense_grid(self):
        """Return the underlying [1, channels, X, Y, Z] parameter."""
        return self.grid

    def __isub__(self, val):
        # In-place subtraction on the raw data (bypasses autograd on purpose).
        self.grid.data -= val
        return self

    def extra_repr(self):
        return f'channels={self.channels}, world_size={self.world_size.tolist()}'
| # ''' Utilize autograd for 3D mask generation | |
| # ''' | |
| # class ConstrainedGrad(torch.autograd.Function): | |
| # @staticmethod | |
| # def forward(ctx, inp): | |
| # if inp.requires_grad: | |
| # ctx.save_for_backward(inp) | |
| # return inp | |
| # @staticmethod | |
| # @torch.autograd.function.once_differentiable | |
| # def backward(ctx, grad_back): | |
| # ''' | |
| # grad_back should be [0,1] | |
| # ''' | |
| # val = ctx.saved_tensors[0] | |
#         return grad_back * (1-val), None, None
| # ''' Dense 3D grid for 3D mask | |
| # ''' | |
| # class MaskDenseGrid(nn.Module): | |
| # def __init__(self, channels, world_size, xyz_min, xyz_max, **kwargs): | |
| # super(MaskDenseGrid, self).__init__() | |
| # self.channels = channels | |
| # self.world_size = world_size | |
| # self.register_buffer('xyz_min', torch.Tensor(xyz_min)) | |
| # self.register_buffer('xyz_max', torch.Tensor(xyz_max)) | |
| # self.grid = nn.Parameter(torch.zeros([1, channels, *world_size])) | |
| # def forward(self, xyz): | |
| # ''' | |
| # xyz: global coordinates to query | |
| # ''' | |
| # shape = xyz.shape[:-1] | |
| # xyz = xyz.reshape(1,1,1,-1,3) | |
| # ind_norm = ((xyz - self.xyz_min) / (self.xyz_max - self.xyz_min)).flip((-1,)) * 2 - 1 | |
| # # modify the backward gradients | |
| # out = F.grid_sample(ConstrainedGrad.apply(self.grid), ind_norm, mode='bilinear', align_corners=True) | |
| # out = out.reshape(self.channels,-1).T.reshape(*shape,self.channels) | |
| # if self.channels == 1: | |
| # out = out.squeeze(-1) | |
| # return out | |
| # @torch.no_grad() | |
| # def scale_volume_grid(self, new_world_size): | |
| # if self.channels == 0: | |
| # self.grid = nn.Parameter(torch.zeros([1, self.channels, *new_world_size])) | |
| # else: | |
| # self.grid = nn.Parameter( | |
| # F.interpolate(self.grid.data, size=tuple(new_world_size), mode='trilinear', align_corners=True)) | |
| # self.world_size = new_world_size | |
| # @torch.no_grad() | |
| # def total_variation_add_grad(self, wx, wy, wz, dense_mode): | |
| # '''Add gradients by total variation loss in-place''' | |
| # total_variation_cuda.total_variation_add_grad( | |
| # self.grid, self.grid.grad, wx, wy, wz, dense_mode) | |
| # @torch.no_grad() | |
| # def get_dense_grid(self): | |
| # return self.grid | |
| # @torch.no_grad() | |
| # def __isub__(self, val): | |
| # self.grid.data -= val | |
| # return self | |
| # def extra_repr(self): | |
| # return f'channels={self.channels}, world_size={self.world_size.tolist()}' | |
''' Vector-Matrix decomposited grid
See TensoRF: Tensorial Radiance Fields (https://arxiv.org/abs/2203.09517)
'''
class TensoRFGrid(nn.Module):
    """Vector-Matrix (VM) decomposed voxel grid.

    The dense [X, Y, Z] feature volume is factorized into three plane factors
    (xy, xz, yz) and three matching line factors (z, y, x); a queried feature
    is the sum of plane*line products over components, optionally projected to
    `channels` dimensions through `f_vec`.
    """
    def __init__(self, channels, world_size, xyz_min, xyz_max, config):
        super(TensoRFGrid, self).__init__()
        self.channels = channels
        self.world_size = world_size
        self.config = config
        self.register_buffer('xyz_min', torch.Tensor(xyz_min))
        self.register_buffer('xyz_max', torch.Tensor(xyz_max))
        X, Y, Z = world_size
        R = config['n_comp']
        # The xy-plane / z-line pair may use its own component count.
        Rxy = config.get('n_comp_xy', R)
        self.xy_plane = nn.Parameter(torch.randn([1, Rxy, X, Y]) * 0.1)
        self.xz_plane = nn.Parameter(torch.randn([1, R, X, Z]) * 0.1)
        self.yz_plane = nn.Parameter(torch.randn([1, R, Y, Z]) * 0.1)
        self.x_vec = nn.Parameter(torch.randn([1, R, X, 1]) * 0.1)
        self.y_vec = nn.Parameter(torch.randn([1, R, Y, 1]) * 0.1)
        self.z_vec = nn.Parameter(torch.randn([1, Rxy, Z, 1]) * 0.1)
        if self.channels > 1:
            # Projection from the R+R+Rxy component features to output channels.
            self.f_vec = nn.Parameter(torch.ones([R+R+Rxy, channels]))
            nn.init.kaiming_uniform_(self.f_vec, a=np.sqrt(5))
    def forward(self, xyz):
        '''
        xyz: [..., 3] global coordinates to query.
        Returns [...] scalar values (channels == 1) or [..., channels] features.
        '''
        shape = xyz.shape[:-1]
        xyz = xyz.reshape(1,1,-1,3)
        ind_norm = (xyz - self.xyz_min) / (self.xyz_max - self.xyz_min) * 2 - 1
        # Append a dummy zero 4th coordinate so the 1D line factors can be
        # sampled with 2D grid_sample (paired with each real axis index).
        ind_norm = torch.cat([ind_norm, torch.zeros_like(ind_norm[...,[0]])], dim=-1)
        if self.channels > 1:
            out = compute_tensorf_feat(
                    self.xy_plane, self.xz_plane, self.yz_plane,
                    self.x_vec, self.y_vec, self.z_vec, self.f_vec, ind_norm)
            out = out.reshape(*shape,self.channels)
        else:
            out = compute_tensorf_val(
                    self.xy_plane, self.xz_plane, self.yz_plane,
                    self.x_vec, self.y_vec, self.z_vec, ind_norm)
            out = out.reshape(*shape)
        return out
    def scale_volume_grid(self, new_world_size):
        # Bilinearly resample every plane/line factor to the new resolution.
        # NOTE(review): self.world_size is not updated here; callers appear to
        # manage it externally — confirm before relying on it.
        if self.channels == 0:
            return
        X, Y, Z = new_world_size
        self.xy_plane = nn.Parameter(F.interpolate(self.xy_plane.data, size=[X,Y], mode='bilinear', align_corners=True))
        self.xz_plane = nn.Parameter(F.interpolate(self.xz_plane.data, size=[X,Z], mode='bilinear', align_corners=True))
        self.yz_plane = nn.Parameter(F.interpolate(self.yz_plane.data, size=[Y,Z], mode='bilinear', align_corners=True))
        self.x_vec = nn.Parameter(F.interpolate(self.x_vec.data, size=[X,1], mode='bilinear', align_corners=True))
        self.y_vec = nn.Parameter(F.interpolate(self.y_vec.data, size=[Y,1], mode='bilinear', align_corners=True))
        self.z_vec = nn.Parameter(F.interpolate(self.z_vec.data, size=[Z,1], mode='bilinear', align_corners=True))
    def total_variation_add_grad(self, wx, wy, wz, dense_mode):
        '''Add gradients by total variation loss in-place'''
        # Smooth-L1 TV between neighboring entries along each spatial axis of
        # every factor; gradients accumulate into the parameters via backward().
        # `dense_mode` is unused here — kept for interface parity with DenseGrid.
        loss = wx * F.smooth_l1_loss(self.xy_plane[:,:,1:], self.xy_plane[:,:,:-1], reduction='sum') +\
               wy * F.smooth_l1_loss(self.xy_plane[:,:,:,1:], self.xy_plane[:,:,:,:-1], reduction='sum') +\
               wx * F.smooth_l1_loss(self.xz_plane[:,:,1:], self.xz_plane[:,:,:-1], reduction='sum') +\
               wz * F.smooth_l1_loss(self.xz_plane[:,:,:,1:], self.xz_plane[:,:,:,:-1], reduction='sum') +\
               wy * F.smooth_l1_loss(self.yz_plane[:,:,1:], self.yz_plane[:,:,:-1], reduction='sum') +\
               wz * F.smooth_l1_loss(self.yz_plane[:,:,:,1:], self.yz_plane[:,:,:,:-1], reduction='sum') +\
               wx * F.smooth_l1_loss(self.x_vec[:,:,1:], self.x_vec[:,:,:-1], reduction='sum') +\
               wy * F.smooth_l1_loss(self.y_vec[:,:,1:], self.y_vec[:,:,:-1], reduction='sum') +\
               wz * F.smooth_l1_loss(self.z_vec[:,:,1:], self.z_vec[:,:,:-1], reduction='sum')
        loss /= 6
        loss.backward()
    def get_dense_grid(self):
        # Expand the factorization back to a dense grid:
        # [1, C, X, Y, Z] when channels > 1, else [1, 1, X, Y, Z].
        if self.channels > 1:
            feat = torch.cat([
                torch.einsum('rxy,rz->rxyz', self.xy_plane[0], self.z_vec[0,:,:,0]),
                torch.einsum('rxz,ry->rxyz', self.xz_plane[0], self.y_vec[0,:,:,0]),
                torch.einsum('ryz,rx->rxyz', self.yz_plane[0], self.x_vec[0,:,:,0]),
            ])
            grid = torch.einsum('rxyz,rc->cxyz', feat, self.f_vec)[None]
        else:
            grid = torch.einsum('rxy,rz->xyz', self.xy_plane[0], self.z_vec[0,:,:,0]) + \
                   torch.einsum('rxz,ry->xyz', self.xz_plane[0], self.y_vec[0,:,:,0]) + \
                   torch.einsum('ryz,rx->xyz', self.yz_plane[0], self.x_vec[0,:,:,0])
            grid = grid[None,None]
        return grid
    def extra_repr(self):
        return f'channels={self.channels}, world_size={self.world_size.tolist()}, n_comp={self.config["n_comp"]}'
def compute_tensorf_feat(xy_plane, xz_plane, yz_plane, x_vec, y_vec, z_vec, f_vec, ind_norm):
    """Query VM-decomposed features at normalized coordinates.

    ind_norm: [1, 1, n_pts, 4] coordinates in [-1, 1]; the 4th entry is a
    dummy zero used to sample the 1D line factors via 2D grid_sample.
    Returns a [n_pts, channels] matrix: plane*line component features
    concatenated, then projected through f_vec.
    """
    def sample(factor, axes):
        # Bilinear lookup of one factor; result shape [n_pts, n_comp].
        coords = ind_norm[:, :, :, axes]
        return F.grid_sample(factor, coords, mode='bilinear', align_corners=True).flatten(0, 2).T

    pairs = (
        (sample(xy_plane, [1, 0]), sample(z_vec, [3, 2])),
        (sample(xz_plane, [2, 0]), sample(y_vec, [3, 1])),
        (sample(yz_plane, [2, 1]), sample(x_vec, [3, 0])),
    )
    feat = torch.cat([plane * line for plane, line in pairs], dim=-1)
    return torch.mm(feat, f_vec)
def compute_tensorf_val(xy_plane, xz_plane, yz_plane, x_vec, y_vec, z_vec, ind_norm):
    """Query scalar (single-channel) VM-decomposed values.

    ind_norm: [1, 1, n_pts, 4] coordinates in [-1, 1]; the 4th entry is a
    dummy zero used to sample the 1D line factors via 2D grid_sample.
    Returns a [n_pts] vector: the component-wise plane*line products summed.
    """
    def sample(factor, axes):
        # Bilinear lookup of one factor; result shape [n_pts, n_comp].
        coords = ind_norm[:, :, :, axes]
        return F.grid_sample(factor, coords, mode='bilinear', align_corners=True).flatten(0, 2).T

    xy = sample(xy_plane, [1, 0])
    xz = sample(xz_plane, [2, 0])
    yz = sample(yz_plane, [2, 1])
    vx = sample(x_vec, [3, 0])
    vy = sample(y_vec, [3, 1])
    vz = sample(z_vec, [3, 2])
    return (xy * vz).sum(-1) + (xz * vy).sum(-1) + (yz * vx).sum(-1)
''' Mask grid
It supports query for the known free space and unknown space.
'''
class MaskGrid(nn.Module):
    """Binary occupancy mask over the scene bounding box.

    Either rebuilt from a saved coarse model checkpoint (`path`) by
    thresholding its density-derived alpha, or wrapped around an explicit
    boolean `mask` with its bounding box.
    """
    def __init__(self, path=None, mask_cache_thres=None, mask=None, xyz_min=None, xyz_max=None):
        super(MaskGrid, self).__init__()
        if path is not None:
            st = torch.load(path)
            self.mask_cache_thres = mask_cache_thres
            # Max-pool dilates the density field by one voxel so borderline
            # voxels survive the threshold.
            density = F.max_pool3d(st['model_state_dict']['density.grid'], kernel_size=3, padding=1, stride=1)
            alpha = 1 - torch.exp(-F.softplus(density + st['model_state_dict']['act_shift']) * st['model_kwargs']['voxel_size_ratio'])
            mask = (alpha >= self.mask_cache_thres).squeeze(0).squeeze(0)
            xyz_min = torch.Tensor(st['model_kwargs']['xyz_min'])
            xyz_max = torch.Tensor(st['model_kwargs']['xyz_max'])
        else:
            mask = mask.bool()
            xyz_min = torch.Tensor(xyz_min)
            xyz_max = torch.Tensor(xyz_max)
        self.register_buffer('mask', mask)
        xyz_len = xyz_max - xyz_min
        # Affine map from world xyz to voxel index ijk: ijk = xyz*scale + shift.
        self.register_buffer('xyz2ijk_scale', (torch.Tensor(list(mask.shape)) - 1) / xyz_len)
        self.register_buffer('xyz2ijk_shift', -xyz_min * self.xyz2ijk_scale)

    def forward(self, xyz):
        '''Skip known free space.
        @xyz: [..., 3] the xyz in global coordinate.
        Returns a boolean tensor of shape [...] (custom CUDA lookup).
        '''
        shape = xyz.shape[:-1]
        xyz = xyz.reshape(-1, 3)
        mask = render_utils_cuda.maskcache_lookup(self.mask, xyz, self.xyz2ijk_scale, self.xyz2ijk_shift)
        mask = mask.reshape(shape)
        return mask

    def extra_repr(self):
        # Bug fix: the original f-string was missing braces and returned the
        # literal text 'mask.shape=list(self.mask.shape)'.
        return f'mask.shape={list(self.mask.shape)}'
def get_dense_grid_batch_processing(tensorf: 'TensoRFGrid'):
    '''
    Reconstruct the dense feature grid from a TensoRF decomposition in
    spatial batches to bound peak device memory.

    The factors are expected to live on the device already; each sub-block is
    computed on that device and copied into a CPU result tensor, so the full
    [1, channels, X, Y, Z] grid never has to fit in device memory at once.

    Bug fix: the channel count was hard-coded to 64 (a leftover "debugging"
    shortcut that also truncated sub-grids with `[:, :64]`); it now uses
    tensorf.channels, and the debug timing prints were removed.

    NOTE(review): like the original, this requires tensorf.f_vec and so only
    supports channels > 1.
    '''
    result_grid = torch.zeros([1, tensorf.channels, *tensorf.world_size],
                              dtype=tensorf.x_vec.dtype)
    # Edge length of each spatial sub-block (y/z batching kept in case a
    # different balance is ever needed).
    batch_size_x = 35
    batch_size_y = 35
    batch_size_z = 35
    for start_x in range(0, tensorf.world_size[0], batch_size_x):
        end_x = start_x + batch_size_x
        for start_y in range(0, tensorf.world_size[1], batch_size_y):
            end_y = start_y + batch_size_y
            for start_z in range(0, tensorf.world_size[2], batch_size_z):
                end_z = start_z + batch_size_z
                # Outer products of the matching plane/line factor slices,
                # stacked over the three decomposition branches.
                feat = torch.cat([
                    torch.einsum('rxy,rz->rxyz', tensorf.xy_plane[0, :, start_x:end_x, start_y:end_y], tensorf.z_vec[0, :, start_z:end_z, 0]),
                    torch.einsum('rxz,ry->rxyz', tensorf.xz_plane[0, :, start_x:end_x, start_z:end_z], tensorf.y_vec[0, :, start_y:end_y, 0]),
                    torch.einsum('ryz,rx->rxyz', tensorf.yz_plane[0, :, start_y:end_y, start_z:end_z], tensorf.x_vec[0, :, start_x:end_x, 0]),
                ])
                # Project components to channels and copy into the CPU result.
                sub_grid = torch.einsum('rxyz,rc->cxyz', feat, tensorf.f_vec)[None]
                result_grid[:, :, start_x:end_x, start_y:end_y, start_z:end_z] = sub_grid
    return result_grid
def reconstruct_feature_grid(render_viewpoints_kwargs):
    '''
    Rebuild the dense feature grid from the model's TensoRF-decomposed
    feature field `f_k0`.

    Returns:
        A tuple of
        - the dense grid zero-padded by 1 voxel on every side
          ([1, C, X+2, Y+2, Z+2], on GPU), and
        - a flattened per-voxel feature matrix [X*Y*Z, C] on CPU,
          laid out for k-means clustering.
    '''
    model = render_viewpoints_kwargs['model']
    f_k0 = model.f_k0.cuda()
    fg = get_dense_grid_batch_processing(f_k0).cuda()
    # Bug fix: derive the feature dimension from the grid instead of the
    # hard-coded 64 used previously.
    channels = fg.shape[1]
    fg_kmeans = fg.clone()
    fg_kmeans = fg_kmeans.squeeze(0).permute(1, 2, 3, 0)  # -> (x, y, z, C)
    fg_kmeans = fg_kmeans.reshape(-1, channels)
    fg_kmeans = fg_kmeans.cpu().contiguous()
    return torch.nn.functional.pad(fg, [1] * 6), fg_kmeans
if __name__ == "__main__":
    # Self-test: check the batched reconstruction against the reference
    # full-GPU implementation, then time it on a larger grid.
    with torch.no_grad():
        print("Testing whether the outputted grid is the correct or not.")
        small = TensoRFGrid(64, torch.tensor([100, 100, 100]), 0, 1, {'n_comp': 64}).cuda()
        t0 = time.time()
        reference = small.get_dense_grid().cpu()
        print("Time taken for full gpu implementation", time.time() - t0)
        batched = get_dense_grid_batch_processing(small)
        assert reference.isclose(batched, atol=1e-7).all()
        del reference, batched, small
        torch.cuda.empty_cache()
        large = TensoRFGrid(64, torch.tensor([320, 320, 320]), 0, 1, {'n_comp': 64}).cuda()
        t0 = time.time()
        grid = get_dense_grid_batch_processing(large)
        print("Time taken to reconstruct the grid", time.time() - t0)
        print("Program over.")