YasiiKB's picture
initial commit
97aa5af verified
"""Feature Extraction and Parameter Prediction networks
"""
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .. utils import sample_and_group_multi
# Number of channels contributed by each raw feature; PPFNet sums these over
# the selected features to get the input width of its pre-pooling network.
_raw_features_sizes = {'xyz': 3, 'dxyz': 3, 'ppf': 4}
# Sort rank of each raw feature; PPFNet uses it as the sort key so the selected
# features are always concatenated in the same order.
_raw_features_order = {'xyz': 0, 'dxyz': 1, 'ppf': 2}
def get_prepool(in_dim, out_dim):
    """Build the shared per-point network applied before max pooling.

    Three stages of 1x1 Conv2d -> GroupNorm(8 groups) -> ReLU. The first two
    stages produce ``out_dim // 2`` channels, the last one ``out_dim``.

    Args:
        in_dim: Number of input channels
        out_dim: Number of output channels

    Returns:
        nn.Sequential holding the nine layers in order
    """
    hidden_dim = out_dim // 2
    channel_plan = (
        (in_dim, hidden_dim),
        (hidden_dim, hidden_dim),
        (hidden_dim, out_dim),
    )

    layers = []
    for c_in, c_out in channel_plan:
        # Same construction order as a flat listing, so Sequential indices
        # (and therefore state_dict keys) are unchanged.
        layers.append(nn.Conv2d(c_in, c_out, 1))
        layers.append(nn.GroupNorm(8, c_out))
        layers.append(nn.ReLU())
    return nn.Sequential(*layers)
def get_postpool(in_dim, out_dim):
    """Build the per-point network applied after max pooling.

    Two stages of 1x1 Conv1d -> GroupNorm(8 groups) -> ReLU followed by a
    final 1x1 Conv1d with no normalization or activation.

    Args:
        in_dim: Number of input channels
        out_dim: Number of output channels. Typically smaller than in_dim

    Returns:
        nn.Sequential holding the seven layers in order
    """
    stages = [
        nn.Conv1d(in_dim, in_dim, 1),
        nn.GroupNorm(8, in_dim),
        nn.ReLU(),
    ]
    stages.extend((
        nn.Conv1d(in_dim, out_dim, 1),
        nn.GroupNorm(8, out_dim),
        nn.ReLU(),
    ))
    # Final projection is left linear (no norm / activation after it).
    stages.append(nn.Conv1d(out_dim, out_dim, 1))
    return nn.Sequential(*stages)
class PPFNet(nn.Module):
    """Feature extraction Module that extracts hybrid features

    The selected raw features are concatenated along the channel axis, passed
    through a shared pre-pooling network, max-pooled, refined by a
    post-pooling network and finally L2-normalized per point.
    """

    def __init__(self, features=('ppf', 'dxyz', 'xyz'), emb_dims=96, radius=0.3, num_neighbors=64):
        """Build the feature extraction layers.

        Args:
            features: Iterable of raw feature names. Each name must be a key
                of ``_raw_features_order`` / ``_raw_features_sizes``; the
                names are stored sorted by ``_raw_features_order``.
            emb_dims: Number of channels of the output descriptor
            radius: Radius passed to ``sample_and_group_multi``
            num_neighbors: Neighbor count passed to ``sample_and_group_multi``
                and used as the size every feature is expanded to
        """
        super().__init__()

        self._logger = logging.getLogger(self.__class__.__name__)
        self._logger.info('Using early fusion, feature dim = {}'.format(emb_dims))
        self.radius = radius
        self.n_sample = num_neighbors

        # sorted() always returns a fresh list, so the caller's iterable (or
        # the immutable default) is never aliased or mutated.
        self.features = sorted(features, key=lambda f: _raw_features_order[f])
        self._logger.info('Feature extraction using features {}'.format(', '.join(self.features)))

        # Layers
        # Builtin sum keeps the channel count a plain Python int rather than
        # a NumPy scalar.
        raw_dim = sum(_raw_features_sizes[f] for f in self.features)  # number of channels after concat
        self.prepool = get_prepool(raw_dim, emb_dims * 2)
        self.postpool = get_postpool(emb_dims * 2, emb_dims)

    def forward(self, xyz, normals):
        """Forward pass of the feature extraction network

        Args:
            xyz: (B, N, 3)
            normals: (B, N, 3)

        Returns:
            cluster features (B, N, C), L2-normalized along the last axis
        """
        features = sample_and_group_multi(-1, self.radius, self.n_sample, xyz, normals)
        # Add a singleton axis at position 2 so 'xyz' can be expanded to
        # n_sample along that axis like the other features.
        features['xyz'] = features['xyz'][:, :, None, :]

        # Gate and concat
        # NOTE(review): assumes every grouped feature is 4-D with the neighbor
        # axis at position 2 and channels last - confirm against
        # sample_and_group_multi.
        concat = [features[f].expand(-1, -1, self.n_sample, -1) for f in self.features]
        fused_input_feat = torch.cat(concat, -1)

        # Prepool_FC, pool, postpool-FC
        new_feat = fused_input_feat.permute(0, 3, 2, 1)  # [B, raw_dim, n_sample, N]
        new_feat = self.prepool(new_feat)

        pooled_feat = torch.max(new_feat, 2)[0]  # Max pooling (B, C, N)

        post_feat = self.postpool(pooled_feat)  # Post pooling dense layers
        cluster_feat = post_feat.permute(0, 2, 1)
        # F.normalize clamps the norm from below, so an all-zero descriptor
        # yields zeros instead of NaN; non-degenerate vectors are unaffected.
        cluster_feat = F.normalize(cluster_feat, p=2, dim=-1)

        return cluster_feat  # (B, N, C)