| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from ast import Dict |
| | import math |
| |
|
| | import numpy as np |
| | import torch |
| | from torch import nn |
| | import torch.nn.functional as F |
| | from torch_scatter import scatter_mean |
| |
|
| | from .unet_3daware import setup_unet |
| | from .conv_pointnet import ConvPointnet |
| |
|
| | from .pc_encoder import PVCNNEncoder |
| |
|
| | import einops |
| |
|
| | from .dnnlib_util import ScopedTorchProfiler, printarr |
| |
|
def generate_plane_features(p, c, resolution, plane='xz', padding=0.):
    """Scatter-average per-point features onto a single 2D feature plane.

    Args:
        p: point coordinates, shape (B, n_p, 3). They are indexed as
            ``p[:, :, axis]`` by ``normalize_coordinate``. Coordinates are
            expected in roughly [-0.5, 0.5]; out-of-range values are clamped
            onto the plane border by ``normalize_coordinate``.
        c: per-point features, shape (B, C, n_p).
        resolution (int): side length of the square output plane.
        plane (str): which coordinate pair to project onto: 'xz', 'xy' or 'yz'.
        padding (float): forwarded to ``normalize_coordinate``. The default of
            0 matches the value this function previously hard-coded.

    Returns:
        tensor of shape (B, C, resolution, resolution). Cells that receive no
        point keep the zero initialisation.
    """
    batch_size = p.size(0)
    c_dim = c.size(1)

    # Map the selected coordinate pair to [0, 1), then to flat cell ids (B, 1, n_p).
    xy = normalize_coordinate(p.clone(), plane=plane, padding=padding)
    index = coordinate2index(xy, resolution)

    # Average all point features that fall into the same cell.
    fea_plane = c.new_zeros(batch_size, c_dim, resolution ** 2)
    fea_plane = scatter_mean(c, index, out=fea_plane)
    return fea_plane.reshape(batch_size, c_dim, resolution, resolution)
| |
|
def normalize_coordinate(p, padding=0.1, plane='xz'):
    """Select two coordinates of each point and map them to [0, 1).

    Args:
        p (tensor): points, indexed as ``p[:, :, axis]``, i.e. shape (B, N, >=3).
        padding (float): conventional padding parameter of ONet for the unit
            cube, so [-0.5, 0.5] -> [-0.55, 0.55] for padding=0.1.
        plane (str): plane feature type, one of ['xz', 'xy', 'yz']. Any other
            value falls back to 'yz'.

    Returns:
        tensor of shape (B, N, 2). The selected coordinates are divided by
        ``1 + padding + 1e-5`` and shifted by 0.5. Values >= 1 are then set
        to ``1 - 1e-5`` and negative values to 0. The input is not modified.
    """
    if plane == 'xz':
        xy = p[:, :, [0, 2]]
    elif plane == 'xy':
        xy = p[:, :, [0, 1]]
    else:
        xy = p[:, :, [1, 2]]

    xy_new = xy / (1 + padding + 10e-6) + 0.5

    # Clamp with masked_fill rather than `if xy_new.max() >= 1:` guards. The
    # values are identical, but there is no data-dependent Python branch
    # (which forces a device->host sync on GPU) and no in-place write.
    xy_new = xy_new.masked_fill(xy_new >= 1, 1 - 10e-6)
    xy_new = xy_new.masked_fill(xy_new < 0, 0.0)
    return xy_new
| |
|
| |
|
def coordinate2index(x, resolution):
    """Convert normalized 2D plane coordinates into flat cell indices.

    Args:
        x (tensor): coordinates of shape (B, N, 2), expected in [0, 1)
            (as produced by ``normalize_coordinate``).
        resolution (int): number of cells along each plane axis.

    Returns:
        LongTensor of shape (B, 1, N) equal to
        ``int(x0 * resolution) + resolution * int(x1 * resolution)``. The
        first coordinate selects the column and the second the row of a
        row-major ``resolution x resolution`` grid.
    """
    # .long() truncates toward zero, which equals floor for inputs in [0, 1).
    cells = (x * resolution).long()
    flat = cells[:, :, 0] + resolution * cells[:, :, 1]
    # Add a singleton channel dim so the index broadcasts against (B, C, N) features.
    return flat.unsqueeze(1)
| |
|
def softclip(x, min, max, hardness=5):
    """Softly clip ``x`` to approximately ``[min, max]``.

    A softplus ramp first lifts values up toward ``min``, and a second,
    mirrored ramp then pulls them down toward ``max``. A larger ``hardness``
    gives a sharper, more clip-like transition.
    """
    lo, hi, k = min, max, hardness
    lifted = lo + F.softplus(k * (x - lo)) / k
    return hi - F.softplus(-k * (lifted - hi)) / k
| |
|
| |
|
def sample_triplane_feat(feature_triplane, normalized_pos):
    """Bilinearly sample a triplane at query points and sum the three planes.

    Args:
        feature_triplane: tensor whose dim 1 indexes the planes. After
            ``unbind(dim=1)`` each plane is passed to ``F.grid_sample`` as its
            input. Plane 0 is sampled at (x, y), plane 1 at (y, z) and plane 2
            at (x, z), where x/y/z are the last-dim components of
            ``normalized_pos``.
        normalized_pos: query positions of shape (B, N, 3) in [-1, 1].
            Out-of-range positions use border padding.

    Returns:
        tensor of shape (B, N, C): the sum of the three sampled features.
    """
    planes = torch.unbind(feature_triplane, dim=1)
    # (first, second) component of normalized_pos used as grid (x, y) per plane.
    axis_pairs = ((0, 1), (1, 2), (0, 2))

    summed = None
    for plane_idx, (u, v) in enumerate(axis_pairs):
        grid = torch.cat(
            [normalized_pos[:, :, u:u + 1], normalized_pos[:, :, v:v + 1]],
            dim=-1,
        ).unsqueeze(dim=1)  # (B, 1, N, 2)
        sampled = F.grid_sample(
            planes[plane_idx],
            grid,
            padding_mode='border',
            align_corners=True,
        )  # (B, C, 1, N)
        summed = sampled if summed is None else summed + sampled

    return summed.squeeze(dim=2).permute(0, 2, 1)
| |
|
| |
|
| | |
class TriPlanePC2Encoder(torch.nn.Module):
    """Encode a point cloud (xyz + per-point features) into a latent triplane.

    The output stacks three feature planes along dim 1 in the order
    ``[xy, yz, xz]``. This is the plane order in which
    ``sample_triplane_feat`` pairs planes with (x, y), (y, z) and (x, z).
    """

    def __init__(
        self,
        cfg,
        device='cuda',
        shape_min=-1.0,
        shape_length=2.0,
        use_2d_feat=False,
    ):
        """Build the point encoder and the optional UNet refiner.

        Config fields read from ``cfg``:
            z_triplane_resolution (int): spatial resolution of each plane.
            z_triplane_channels (int): channels of each plane; also used as
                the point encoder's output width.
            point_encoder_type (str): one of ['pvcnn', 'pointnet'].
            use_point_scatter (bool): read in ``encode``. For 'pvcnn' it
                selects scattering per-point features onto the planes instead
                of averaging voxel features. It must be True for 'pointnet'.
            unet_cfg: object with an ``enabled`` flag; passed to ``setup_unet``.

        Args:
            cfg: configuration object (attribute access).
            device: stored on the module and passed to ``PVCNNEncoder``.
            shape_min: lower bound of the input coordinate range. ``encode``
                maps [shape_min, shape_min + shape_length] to [-0.5, 0.5].
            shape_length: extent of the input coordinate range.
            use_2d_feat: forwarded to ``PVCNNEncoder``.

        Raises:
            NotImplementedError: if ``cfg.point_encoder_type`` is unsupported.
        """
        super().__init__()
        self.device = device
        self.cfg = cfg

        self.shape_min = shape_min
        self.shape_length = shape_length

        self.z_triplane_resolution = cfg.z_triplane_resolution
        z_triplane_channels = cfg.z_triplane_channels
        point_encoder_out_dim = z_triplane_channels

        # xyz (3) concatenated with point_cloud_feature in encode().
        in_channels = 6

        if cfg.point_encoder_type == 'pvcnn':
            self.pc_encoder = PVCNNEncoder(
                point_encoder_out_dim,
                device=self.device,
                in_channels=in_channels,
                use_2d_feat=use_2d_feat,
            )
        elif cfg.point_encoder_type == 'pointnet':
            self.pc_encoder = ConvPointnet(
                c_dim=point_encoder_out_dim,
                dim=in_channels,
                hidden_dim=32,
                plane_resolution=self.z_triplane_resolution,
                padding=0,
            )
        else:
            raise NotImplementedError(f"Point encoder {cfg.point_encoder_type} not implemented")

        if cfg.unet_cfg.enabled:
            self.unet_encoder = setup_unet(
                output_channels=point_encoder_out_dim,
                input_channels=point_encoder_out_dim,
                unet_cfg=cfg.unet_cfg,
            )
        else:
            self.unet_encoder = None

    def _encode_pvcnn(self, point_cloud_xyz, point_cloud, mv_feat, pc2pc_idx):
        """Run the PVCNN encoder and return its (xy, yz, xz) feature planes."""
        if mv_feat is not None:
            pc_feat, points_feat = self.pc_encoder(point_cloud, mv_feat, pc2pc_idx)
        else:
            pc_feat, points_feat = self.pc_encoder(point_cloud)

        if self.cfg.use_point_scatter:
            # Scatter-average the per-point features onto each plane.
            points_feat_ = points_feat[0]
            return tuple(
                generate_plane_features(
                    point_cloud_xyz, points_feat_,
                    resolution=self.z_triplane_resolution, plane=plane)
                for plane in ('xy', 'yz', 'xz')
            )

        # Otherwise average the voxel features along one axis per plane and
        # upsample by integer repetition.
        # NOTE(review): assumes voxel features are (B, C, X, Y, Z) on a 32^3
        # grid (hence the hard-coded 32) — TODO confirm against PVCNNEncoder.
        voxel_feat = pc_feat[0]
        sf = self.z_triplane_resolution // 32

        pc_feat_1 = torch.mean(voxel_feat, dim=-1)
        pc_feat_2 = torch.mean(voxel_feat, dim=-3)
        pc_feat_3 = torch.mean(voxel_feat, dim=-2)

        pc_feat_1 = einops.repeat(pc_feat_1, 'b c h w -> b c (h hm ) (w wm)', hm=sf, wm=sf)
        pc_feat_2 = einops.repeat(pc_feat_2, 'b c h w -> b c (h hm) (w wm)', hm=sf, wm=sf)
        pc_feat_3 = einops.repeat(pc_feat_3, 'b c h w -> b c (h hm) (w wm)', hm=sf, wm=sf)
        return pc_feat_1, pc_feat_2, pc_feat_3

    def encode(self, point_cloud_xyz, point_cloud_feature, mv_feat=None, pc2pc_idx=None) -> torch.Tensor:
        """Encode a point cloud into a stacked [xy, yz, xz] triplane tensor.

        Args:
            point_cloud_xyz: point coordinates, concatenated with
                ``point_cloud_feature`` along the last dim. Presumably
                (B, N, 3); TODO confirm with callers.
            point_cloud_feature: per-point features. The encoder is built for
                6 input channels, so this presumably carries 3 channels.
            mv_feat: optional extra features, only forwarded to the PVCNN encoder.
            pc2pc_idx: forwarded to the PVCNN encoder together with ``mv_feat``.

        Returns:
            torch.Tensor: the three planes stacked along dim 1 (xy, yz, xz),
            refined by the UNet when one is configured.

        Raises:
            TypeError: if ``point_cloud_feature`` is None.
            NotImplementedError: if ``cfg.point_encoder_type`` is unsupported.
        """
        if point_cloud_feature is None:
            raise TypeError(
                "point_cloud_feature is required: the point encoder expects "
                "xyz concatenated with per-point features")

        # Map [shape_min, shape_min + shape_length] to [-0.5, 0.5].
        point_cloud_xyz = (point_cloud_xyz - self.shape_min) / self.shape_length
        point_cloud_xyz = point_cloud_xyz - 0.5
        point_cloud = torch.cat([point_cloud_xyz, point_cloud_feature], dim=-1)

        if self.cfg.point_encoder_type == 'pvcnn':
            pc_feat_1, pc_feat_2, pc_feat_3 = self._encode_pvcnn(
                point_cloud_xyz, point_cloud, mv_feat, pc2pc_idx)
        elif self.cfg.point_encoder_type == 'pointnet':
            assert self.cfg.use_point_scatter, "pointnet encoder requires use_point_scatter"
            pc_feat = self.pc_encoder(point_cloud)
            pc_feat_1 = pc_feat['xy']
            pc_feat_2 = pc_feat['yz']
            pc_feat_3 = pc_feat['xz']
        else:
            raise NotImplementedError()

        if self.unet_encoder is not None:
            # The UNet refines all three planes jointly.
            pc_feat_tri_plane_stack_pre = torch.stack([pc_feat_1, pc_feat_2, pc_feat_3], dim=1)
            pc_feat_tri_plane_stack = self.unet_encoder(pc_feat_tri_plane_stack_pre)
            pc_feat_1, pc_feat_2, pc_feat_3 = torch.unbind(pc_feat_tri_plane_stack, dim=1)

        return torch.stack([pc_feat_1, pc_feat_2, pc_feat_3], dim=1)

    def forward(self, point_cloud_xyz, point_cloud_feature=None, mv_feat=None, pc2pc_idx=None):
        """Module entry point; delegates to :meth:`encode`."""
        return self.encode(
            point_cloud_xyz,
            point_cloud_feature=point_cloud_feature,
            mv_feat=mv_feat,
            pc2pc_idx=pc2pc_idx,
        )