| |
| |
| |
| |
| |
| |
| |
|
|
| from ast import Dict |
| import math |
|
|
| import numpy as np |
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
| from torch_scatter import scatter_mean |
|
|
| from .unet_3daware import setup_unet |
| from .conv_pointnet import ConvPointnet |
|
|
| from .pc_encoder import PVCNNEncoder |
|
|
| import einops |
|
|
| from .dnnlib_util import ScopedTorchProfiler, printarr |
|
|
def generate_plane_features(p, c, resolution, plane='xz', padding=0.):
    """Average per-point features onto one axis-aligned feature plane.

    Each point is projected onto ``plane`` and its 2D coordinate is mapped to a
    flat cell index on a ``resolution x resolution`` grid. The features of all
    points landing in the same cell are combined with ``scatter_mean``.

    Args:
        p: (B, n_p, 3) point coordinates. Coordinates are selected along the
            last axis by ``normalize_coordinate``. With ``padding=0``, points in
            [-0.5, 0.5] map onto the full grid; anything outside is clamped into
            the border cells.
        c: (B, C, n_p) per-point features.
        resolution (int): side length of the output plane.
        plane (str): one of 'xz', 'xy', 'yz'; selects which coordinates of
            ``p`` are kept.
        padding (float): forwarded to ``normalize_coordinate``. Defaults to 0,
            the value this function previously hard-coded.

    Returns:
        (B, C, resolution, resolution) tensor. The output buffer is
        zero-initialized, so cells that receive no point should stay zero.
    """
    batch_size = p.size(0)
    c_dim = c.size(1)

    # normalize_coordinate indexes p with a list (advanced indexing), which
    # already copies, so p itself is never modified and no clone is needed.
    xy = normalize_coordinate(p, plane=plane, padding=padding)
    index = coordinate2index(xy, resolution)  # (B, 1, n_p), broadcast over C

    fea_plane = c.new_zeros(batch_size, c_dim, resolution ** 2)
    fea_plane = scatter_mean(c, index, out=fea_plane)
    return fea_plane.reshape(batch_size, c_dim, resolution, resolution)
|
|
def normalize_coordinate(p, padding=0.1, plane='xz'):
    """Project points onto a plane and map their coordinates into [0, 1).

    Args:
        p (tensor): (B, N, 3) points; coordinates are selected on the last
            axis. For unit-cube experiments the expected range is [-0.5, 0.5].
        padding (float): conventional ONet padding parameter for the unit cube,
            i.e. with padding=0.1 the range [-0.5, 0.5] is treated as
            [-0.55, 0.55].
        plane (str): plane feature type, one of 'xz', 'xy', 'yz'. Any other
            value is treated as 'yz'.

    Returns:
        (B, N, 2) tensor with values in [0, 1). Values that land at or above 1
        are set to ``1 - 10e-6``; negative values are set to 0.
    """
    if plane == 'xz':
        xy = p[:, :, [0, 2]]
    elif plane == 'xy':
        xy = p[:, :, [0, 1]]
    else:
        xy = p[:, :, [1, 2]]

    # The extra 10e-6 (== 1e-5) enlarges the divisor slightly so that a point
    # exactly on the padded boundary still lands just inside the unit range.
    xy_new = xy / (1 + padding + 10e-6) + 0.5

    # Out-of-place masked clamping. This gives the same result as the previous
    # `if xy_new.max() >= 1:` / `if xy_new.min() < 0:` guards, but without the
    # device->host synchronization those Python-level checks force on GPU.
    # Values already in [1 - 1e-5, 1) are left untouched, exactly as before.
    xy_new = xy_new.masked_fill(xy_new >= 1, 1 - 10e-6)
    xy_new = xy_new.masked_fill(xy_new < 0, 0.0)
    return xy_new
|
|
|
|
def coordinate2index(x, resolution):
    """Convert normalized 2D plane coordinates into flat grid-cell indices.

    Args:
        x (tensor): (B, N, 2) coordinates, expected in [0, 1) as produced by
            ``normalize_coordinate``. They are scaled by ``resolution`` and
            truncated with ``.long()``.
        resolution (int): side length of the square grid.

    Returns:
        (B, 1, N) int64 tensor holding ``col + resolution * row``. ``col`` comes
        from ``x[..., 0]`` and ``row`` from ``x[..., 1]``, i.e. a row-major
        index into a ``resolution x resolution`` grid.
    """
    cells = (x * resolution).long()
    flat_index = cells[:, :, 0] + resolution * cells[:, :, 1]
    # Singleton channel axis so the index broadcasts over feature channels.
    return flat_index.unsqueeze(1)
|
|
def softclip(x, min, max, hardness=5):
    """Smoothly squash ``x`` toward the interval [min, max].

    Applies a softplus-based soft floor at ``min``, then a softplus-based soft
    ceiling at ``max``. A larger ``hardness`` makes both bends sharper.

    The result is always strictly below ``max``. Because the ceiling is applied
    after the floor, it may dip slightly below ``min``.
    """
    inv_hardness = hardness
    floored = min + F.softplus(hardness * (x - min)) / inv_hardness
    ceiling_term = F.softplus(-hardness * (floored - max))
    return max - ceiling_term / inv_hardness
|
|
|
|
def sample_triplane_feat(feature_triplane, normalized_pos):
    """Bilinearly sample a triplane at 3D positions and sum the three planes.

    Args:
        feature_triplane: tensor whose dim 1 indexes the planes. Each plane is
            passed to ``F.grid_sample`` as a 4D (B, C, H, W) input, so this is
            (B, P, C, H, W); only the first three planes are used.
        normalized_pos: (B, N, 3) query positions in [-1, 1]. Samples outside
            that range use border padding.

    Plane 0 is sampled at (x, y), plane 1 at (y, z) and plane 2 at (x, z), all
    with ``align_corners=True``. ``grid_sample`` treats the first grid
    coordinate as the width axis and the second as the height axis.

    Returns:
        (B, N, C) tensor: the sum of the three sampled plane features.
    """
    planes = torch.unbind(feature_triplane, dim=1)
    # One (width, height) coordinate pair per plane, in plane order.
    axis_pairs = ([0, 1], [1, 2], [0, 2])

    sampled = []
    for plane, pair in zip(planes, axis_pairs):
        grid = normalized_pos[:, :, pair].unsqueeze(dim=1)  # (B, 1, N, 2)
        sampled.append(
            F.grid_sample(
                plane,
                grid,
                padding_mode='border',
                align_corners=True,
            )
        )

    summed = sampled[0] + sampled[1] + sampled[2]  # (B, C, 1, N)
    return summed.squeeze(dim=2).permute(0, 2, 1)
|
|
|
|
| |
class TriPlanePC2Encoder(torch.nn.Module):
    """Encode a point cloud (xyz plus per-point features) into a feature triplane."""

    def __init__(
        self,
        cfg,
        device='cuda',
        shape_min=-1.0,
        shape_length=2.0,
        use_2d_feat=False,
    ):
        """Build the point encoder and the optional triplane U-Net.

        Config fields read here:
            point_encoder_type (str): 'pvcnn' or 'pointnet'.
            z_triplane_channels (int): channels of the output triplane. Also
                used as the point encoder's output width.
            z_triplane_resolution (int): spatial resolution of each plane.
            unet_cfg: object with an ``enabled`` flag; passed whole to
                ``setup_unet`` when enabled.
        Config field read later by ``encode``:
            use_point_scatter (bool): for 'pvcnn', whether to scatter per-point
                features onto the planes instead of averaging the voxel grid.
                Must be True for 'pointnet'.

        Args:
            cfg: configuration object (attribute access).
            device: stored on the module and forwarded to ``PVCNNEncoder``.
            shape_min: lower bound of the input coordinate range.
            shape_length: extent of the input coordinate range. ``encode`` maps
                [shape_min, shape_min + shape_length] to [-0.5, 0.5].
            use_2d_feat: forwarded to ``PVCNNEncoder``.

        Raises:
            NotImplementedError: for an unknown ``point_encoder_type``.
        """
        super().__init__()
        self.device = device
        self.cfg = cfg
        self.shape_min = shape_min
        self.shape_length = shape_length
        self.z_triplane_resolution = cfg.z_triplane_resolution

        point_encoder_out_dim = cfg.z_triplane_channels
        # xyz (3) + point_cloud_feature channels, concatenated in `encode`;
        # presumably 3 feature channels — confirm against callers.
        in_channels = 6

        if cfg.point_encoder_type == 'pvcnn':
            self.pc_encoder = PVCNNEncoder(
                point_encoder_out_dim,
                device=self.device,
                in_channels=in_channels,
                use_2d_feat=use_2d_feat,
            )
        elif cfg.point_encoder_type == 'pointnet':
            self.pc_encoder = ConvPointnet(
                c_dim=point_encoder_out_dim,
                dim=in_channels,
                hidden_dim=32,
                plane_resolution=self.z_triplane_resolution,
                padding=0,
            )
        else:
            raise NotImplementedError(
                f"Point encoder {cfg.point_encoder_type} not implemented")

        if cfg.unet_cfg.enabled:
            self.unet_encoder = setup_unet(
                output_channels=point_encoder_out_dim,
                input_channels=point_encoder_out_dim,
                unet_cfg=cfg.unet_cfg,
            )
        else:
            self.unet_encoder = None

    def encode(self, point_cloud_xyz, point_cloud_feature, mv_feat=None,
               pc2pc_idx=None) -> torch.Tensor:
        """Encode a point cloud into a stacked triplane.

        Args:
            point_cloud_xyz: (B, N, 3) coordinates, expected in
                [shape_min, shape_min + shape_length].
            point_cloud_feature: per-point features, concatenated to the
                normalized xyz along the last axis. Required.
            mv_feat, pc2pc_idx: extra inputs forwarded to the PVCNN encoder
                when ``mv_feat`` is not None; ignored otherwise.

        Returns:
            Tensor with the 'xy', 'yz' and 'xz' planes stacked along dim 1, in
            that order, after the optional U-Net. With point scattering each
            plane is (B, C, R, R) with R = z_triplane_resolution.

        Raises:
            TypeError: if ``point_cloud_feature`` is None.
            NotImplementedError: for an unknown ``point_encoder_type``.
        """
        if point_cloud_feature is None:
            # torch.cat would raise an opaque TypeError here; keep the exception
            # type but explain what is missing.
            raise TypeError(
                "point_cloud_feature is required: the point encoder consumes "
                "xyz concatenated with per-point features")

        # Map [shape_min, shape_min + shape_length] to [-0.5, 0.5], the range
        # generate_plane_features (padding=0) maps onto the full plane grid.
        point_cloud_xyz = (point_cloud_xyz - self.shape_min) / self.shape_length - 0.5
        point_cloud = torch.cat([point_cloud_xyz, point_cloud_feature], dim=-1)

        encoder_type = self.cfg.point_encoder_type
        if encoder_type == 'pvcnn':
            if mv_feat is not None:
                pc_feat, points_feat = self.pc_encoder(point_cloud, mv_feat, pc2pc_idx)
            else:
                pc_feat, points_feat = self.pc_encoder(point_cloud)

            if self.cfg.use_point_scatter:
                # Average per-point features into each plane's grid cells.
                # NOTE(review): points_feat[0] is assumed to be (B, C, N), as
                # generate_plane_features expects — confirm with PVCNNEncoder.
                points_feat_ = points_feat[0]
                res = self.z_triplane_resolution
                pc_feat_1 = generate_plane_features(point_cloud_xyz, points_feat_,
                                                    resolution=res, plane='xy')
                pc_feat_2 = generate_plane_features(point_cloud_xyz, points_feat_,
                                                    resolution=res, plane='yz')
                pc_feat_3 = generate_plane_features(point_cloud_xyz, points_feat_,
                                                    resolution=res, plane='xz')
            else:
                # Collapse the voxel grid along one axis per plane, then repeat
                # each cell sf x sf times (nearest-neighbour upsampling).
                # NOTE(review): assumes pc_feat[0] is (B, C, X, Y, Z) on a 32^3
                # grid; sf becomes 0 if z_triplane_resolution < 32 — confirm.
                pc_feat = pc_feat[0]
                sf = self.z_triplane_resolution // 32

                pc_feat_1 = torch.mean(pc_feat, dim=-1)
                pc_feat_2 = torch.mean(pc_feat, dim=-3)
                pc_feat_3 = torch.mean(pc_feat, dim=-2)

                pc_feat_1 = einops.repeat(pc_feat_1, 'b c h w -> b c (h hm ) (w wm)', hm=sf, wm=sf)
                pc_feat_2 = einops.repeat(pc_feat_2, 'b c h w -> b c (h hm) (w wm)', hm=sf, wm=sf)
                pc_feat_3 = einops.repeat(pc_feat_3, 'b c h w -> b c (h hm) (w wm)', hm=sf, wm=sf)
        elif encoder_type == 'pointnet':
            assert self.cfg.use_point_scatter, \
                "pointnet point encoder requires cfg.use_point_scatter=True"
            pc_feat = self.pc_encoder(point_cloud)
            pc_feat_1 = pc_feat['xy']
            pc_feat_2 = pc_feat['yz']
            pc_feat_3 = pc_feat['xz']
        else:
            raise NotImplementedError(f"Point encoder {encoder_type} not implemented")

        triplane = torch.stack([pc_feat_1, pc_feat_2, pc_feat_3], dim=1)
        if self.unet_encoder is not None:
            # The U-Net consumes and returns the stacked triplane directly.
            triplane = self.unet_encoder(triplane)
        return triplane

    def forward(self, point_cloud_xyz, point_cloud_feature=None, mv_feat=None, pc2pc_idx=None):
        """Delegate to :meth:`encode`; ``point_cloud_feature`` must be provided."""
        return self.encode(point_cloud_xyz, point_cloud_feature=point_cloud_feature,
                           mv_feat=mv_feat, pc2pc_idx=pc2pc_idx)