| """Feature Extraction and Parameter Prediction networks |
| """ |
| import logging |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .. utils import sample_and_group_multi |
|
|
| _raw_features_sizes = {'xyz': 3, 'dxyz': 3, 'ppf': 4} |
| _raw_features_order = {'xyz': 0, 'dxyz': 1, 'ppf': 2} |
|
|
|
|
| def get_prepool(in_dim, out_dim): |
| """Shared FC part in PointNet before max pooling""" |
| net = nn.Sequential( |
| nn.Conv2d(in_dim, out_dim // 2, 1), |
| nn.GroupNorm(8, out_dim // 2), |
| nn.ReLU(), |
| nn.Conv2d(out_dim // 2, out_dim // 2, 1), |
| nn.GroupNorm(8, out_dim // 2), |
| nn.ReLU(), |
| nn.Conv2d(out_dim // 2, out_dim, 1), |
| nn.GroupNorm(8, out_dim), |
| nn.ReLU(), |
| ) |
| return net |
|
|
|
|
| def get_postpool(in_dim, out_dim): |
| """Linear layers in PointNet after max pooling |
| |
| Args: |
| in_dim: Number of input channels |
| out_dim: Number of output channels. Typically smaller than in_dim |
| |
| """ |
| net = nn.Sequential( |
| nn.Conv1d(in_dim, in_dim, 1), |
| nn.GroupNorm(8, in_dim), |
| nn.ReLU(), |
| nn.Conv1d(in_dim, out_dim, 1), |
| nn.GroupNorm(8, out_dim), |
| nn.ReLU(), |
| nn.Conv1d(out_dim, out_dim, 1), |
| ) |
|
|
| return net |
|
|
|
|
| class PPFNet(nn.Module): |
| """Feature extraction Module that extracts hybrid features""" |
| def __init__(self, features=['ppf', 'dxyz', 'xyz'], emb_dims=96, radius=0.3, num_neighbors=64): |
| super().__init__() |
|
|
| self._logger = logging.getLogger(self.__class__.__name__) |
| self._logger.info('Using early fusion, feature dim = {}'.format(emb_dims)) |
| self.radius = radius |
| self.n_sample = num_neighbors |
|
|
| self.features = sorted(features, key=lambda f: _raw_features_order[f]) |
| self._logger.info('Feature extraction using features {}'.format(', '.join(self.features))) |
|
|
| |
| raw_dim = np.sum([_raw_features_sizes[f] for f in self.features]) |
| self.prepool = get_prepool(raw_dim, emb_dims * 2) |
| self.postpool = get_postpool(emb_dims * 2, emb_dims) |
|
|
| def forward(self, xyz, normals): |
| """Forward pass of the feature extraction network |
| |
| Args: |
| xyz: (B, N, 3) |
| normals: (B, N, 3) |
| |
| Returns: |
| cluster features (B, N, C) |
| |
| """ |
| features = sample_and_group_multi(-1, self.radius, self.n_sample, xyz, normals) |
| features['xyz'] = features['xyz'][:, :, None, :] |
|
|
| |
| concat = [] |
| for i in range(len(self.features)): |
| f = self.features[i] |
| expanded = (features[f]).expand(-1, -1, self.n_sample, -1) |
| concat.append(expanded) |
| fused_input_feat = torch.cat(concat, -1) |
|
|
| |
| new_feat = fused_input_feat.permute(0, 3, 2, 1) |
| new_feat = self.prepool(new_feat) |
|
|
| pooled_feat = torch.max(new_feat, 2)[0] |
|
|
| post_feat = self.postpool(pooled_feat) |
| cluster_feat = post_feat.permute(0, 2, 1) |
| cluster_feat = cluster_feat / torch.norm(cluster_feat, dim=-1, keepdim=True) |
|
|
| return cluster_feat |