File size: 2,894 Bytes
97aa5af | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 | """Feature Extraction and Parameter Prediction networks
"""
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .. utils import sample_and_group_multi
# Channel count contributed by each raw feature; summed over the selected
# features to get the input width of the pre-pooling network.
_raw_features_sizes = {'xyz': 3, 'dxyz': 3, 'ppf': 4}
# Sort rank of each raw feature; used as a sort key so the selected features are
# always concatenated in the same order, whatever order the caller lists them in.
_raw_features_order = {'xyz': 0, 'dxyz': 1, 'ppf': 2}
def get_prepool(in_dim, out_dim):
    """Build the shared per-point layers that run before max pooling.

    The stack is three stages of 1x1 ``Conv2d`` -> ``GroupNorm`` (8 groups) ->
    ``ReLU``. The first two stages produce ``out_dim // 2`` channels and the
    last one produces ``out_dim`` channels.

    Args:
        in_dim: Number of input channels
        out_dim: Number of output channels

    Returns:
        ``nn.Sequential`` containing the nine layers in order
    """
    half_dim = out_dim // 2
    # (input channels, output channels) of each conv stage, in execution order
    channel_plan = [(in_dim, half_dim), (half_dim, half_dim), (half_dim, out_dim)]

    layers = []
    for n_in, n_out in channel_plan:
        layers.extend([
            nn.Conv2d(n_in, n_out, 1),
            nn.GroupNorm(8, n_out),
            nn.ReLU(),
        ])
    return nn.Sequential(*layers)
def get_postpool(in_dim, out_dim):
    """Build the per-point layers that run after max pooling.

    The stack is two stages of 1x1 ``Conv1d`` -> ``GroupNorm`` (8 groups) ->
    ``ReLU`` (``in_dim -> in_dim``, then ``in_dim -> out_dim``), followed by a
    final 1x1 ``Conv1d`` (``out_dim -> out_dim``) with no normalisation or
    activation after it.

    Args:
        in_dim: Number of input channels
        out_dim: Number of output channels. Typically smaller than in_dim

    Returns:
        ``nn.Sequential`` containing the seven layers in order
    """
    # Keeps the channel width
    hidden_stage = [nn.Conv1d(in_dim, in_dim, 1), nn.GroupNorm(8, in_dim), nn.ReLU()]
    # Changes the width to out_dim
    reduce_stage = [nn.Conv1d(in_dim, out_dim, 1), nn.GroupNorm(8, out_dim), nn.ReLU()]
    # Final projection, left without norm/activation
    projection = nn.Conv1d(out_dim, out_dim, 1)

    return nn.Sequential(*hidden_stage, *reduce_stage, projection)
class PPFNet(nn.Module):
    """Feature extraction Module that extracts hybrid features

    The selected raw features of each point neighbourhood are concatenated
    along the channel axis (early fusion), passed through the shared pre-pool
    layers, max-pooled over the neighbourhood, passed through the post-pool
    layers and finally L2-normalised per point.
    """

    def __init__(self, features=('ppf', 'dxyz', 'xyz'), emb_dims=96, radius=0.3, num_neighbors=64):
        """Set up the feature extraction layers

        Args:
            features: Names of the raw features to use; each must be a key of
                      ``_raw_features_order`` ('xyz', 'dxyz', 'ppf'). Any iterable
                      of names is accepted; it is never mutated.
            emb_dims: Number of channels of the output feature
            radius: Passed to ``sample_and_group_multi`` (presumably the
                    neighbourhood radius -- confirm against that helper)
            num_neighbors: Passed to ``sample_and_group_multi`` as the sample
                           count, and used as the size of the neighbourhood axis

        Raises:
            KeyError: If ``features`` contains an unknown feature name
        """
        super().__init__()

        self._logger = logging.getLogger(self.__class__.__name__)
        self._logger.info('Using early fusion, feature dim = {}'.format(emb_dims))
        self.radius = radius
        self.n_sample = num_neighbors

        # sorted() builds a new list, so the caller's sequence is left untouched and
        # the concatenation order is fixed regardless of the order given.
        self.features = sorted(features, key=lambda f: _raw_features_order[f])
        self._logger.info('Feature extraction using features {}'.format(', '.join(self.features)))

        # Layers
        # Builtin sum keeps the channel count a plain Python int (np.sum would
        # hand a NumPy scalar to the conv layer constructors).
        raw_dim = sum(_raw_features_sizes[f] for f in self.features)  # number of channels after concat
        self.prepool = get_prepool(raw_dim, emb_dims * 2)
        self.postpool = get_postpool(emb_dims * 2, emb_dims)

    def forward(self, xyz, normals):
        """Forward pass of the feature extraction network

        Args:
            xyz: (B, N, 3)
            normals: (B, N, 3)

        Returns:
            cluster features (B, N, C)
        """
        features = sample_and_group_multi(-1, self.radius, self.n_sample, xyz, normals)
        # Insert a singleton neighbourhood axis so 'xyz' can be expanded below
        features['xyz'] = features['xyz'][:, :, None, :]

        # Gate and concat
        # expand() broadcasts every selected feature to n_sample entries on axis 2
        # before they are joined along the last (channel) axis.
        concat = [features[name].expand(-1, -1, self.n_sample, -1) for name in self.features]
        fused_input_feat = torch.cat(concat, -1)

        # Prepool_FC, pool, postpool-FC
        new_feat = fused_input_feat.permute(0, 3, 2, 1)  # [B, raw_dim, n_sample, N]
        new_feat = self.prepool(new_feat)

        pooled_feat = torch.max(new_feat, 2)[0]  # Max pooling (B, C, N)

        post_feat = self.postpool(pooled_feat)  # Post pooling dense layers
        cluster_feat = post_feat.permute(0, 2, 1)
        # L2-normalise each point's feature. F.normalize clamps the norm from below,
        # so an all-zero feature stays zero rather than becoming NaN through 0 / 0.
        cluster_feat = F.normalize(cluster_feat, p=2, dim=-1)

        return cluster_feat  # (B, N, C)