File size: 2,894 Bytes
97aa5af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""Feature Extraction and Parameter Prediction networks
"""
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .. utils import sample_and_group_multi

# Number of channels each raw feature contributes; PPFNet sums these to size
# the first prepool convolution.
_raw_features_sizes = dict(xyz=3, dxyz=3, ppf=4)
# Sort key used by PPFNet to put the selected features into a fixed order
# before they are concatenated along the channel axis.
_raw_features_order = dict(xyz=0, dxyz=1, ppf=2)


def get_prepool(in_dim, out_dim):
	"""Build the shared per-point layers applied before max pooling.

	The stack is three stages of 1x1 Conv2d -> GroupNorm(8 groups) -> ReLU,
	widening the channels as in_dim -> out_dim // 2 -> out_dim // 2 -> out_dim.

	Args:
		in_dim: Number of input channels
		out_dim: Number of output channels. Both out_dim // 2 and out_dim
		  must be divisible by 8 because GroupNorm is built with 8 groups

	Returns:
		nn.Sequential holding the nine layers in order
	"""
	half_dim = out_dim // 2
	channel_plan = ((in_dim, half_dim), (half_dim, half_dim), (half_dim, out_dim))

	layers = []
	for c_in, c_out in channel_plan:
		# Kernel size 1 means every position is transformed independently
		# with shared weights.
		layers.append(nn.Conv2d(c_in, c_out, 1))
		layers.append(nn.GroupNorm(8, c_out))
		layers.append(nn.ReLU())
	return nn.Sequential(*layers)


def get_postpool(in_dim, out_dim):
	"""Build the per-point layers applied after max pooling.

	Two stages of 1x1 Conv1d -> GroupNorm(8 groups) -> ReLU map the channels
	in_dim -> in_dim -> out_dim, followed by a final 1x1 Conv1d
	(out_dim -> out_dim) that has no normalisation or activation.

	Args:
		in_dim: Number of input channels
		out_dim: Number of output channels. Typically smaller than in_dim.
		  Both in_dim and out_dim must be divisible by 8 because GroupNorm
		  is built with 8 groups

	Returns:
		nn.Sequential holding the seven layers in order
	"""
	stages = []
	for c_in, c_out in ((in_dim, in_dim), (in_dim, out_dim)):
		stages += [nn.Conv1d(c_in, c_out, 1), nn.GroupNorm(8, c_out), nn.ReLU()]

	# Final linear projection: deliberately left without norm/activation.
	stages.append(nn.Conv1d(out_dim, out_dim, 1))
	return nn.Sequential(*stages)


class PPFNet(nn.Module):
	"""Feature extraction Module that extracts hybrid features

	The selected raw features are concatenated along the channel axis (early
	fusion), passed through the shared prepool layers, max pooled over the
	neighbour axis, passed through the postpool layers and finally
	L2-normalised per point.
	"""
	def __init__(self, features=('ppf', 'dxyz', 'xyz'), emb_dims=96, radius=0.3, num_neighbors=64):
		"""Set up the prepool/postpool networks for the chosen features.

		Args:
			features: Iterable of raw feature names to use; each must be a
			  key of _raw_features_order (an unknown name raises KeyError)
			emb_dims: Number of channels of the output descriptor
			radius: Radius forwarded to sample_and_group_multi
			num_neighbors: Number of neighbours forwarded to
			  sample_and_group_multi, and the size every feature is
			  expanded to along the neighbour axis
		"""
		super().__init__()

		self._logger = logging.getLogger(self.__class__.__name__)
		self._logger.info('Using early fusion, feature dim = {}'.format(emb_dims))
		self.radius = radius
		self.n_sample = num_neighbors

		# Sort into a canonical order so the channel layout does not depend
		# on the order in which the caller listed the features.
		self.features = sorted(features, key=lambda f: _raw_features_order[f])
		self._logger.info('Feature extraction using features {}'.format(', '.join(self.features)))

		# Layers
		# Builtin sum keeps the channel count a plain Python int rather than a
		# NumPy scalar.
		raw_dim = sum(_raw_features_sizes[f] for f in self.features)  # number of channels after concat
		self.prepool = get_prepool(raw_dim, emb_dims * 2)
		self.postpool = get_postpool(emb_dims * 2, emb_dims)

	def forward(self, xyz, normals):
		"""Forward pass of the feature extraction network

		Args:
			xyz: (B, N, 3)
			normals: (B, N, 3)

		Returns:
			cluster features (B, N, C), unit L2 norm along the last axis
			(all-zero descriptors are returned as zeros instead of NaN)

		"""
		# NOTE(review): presumably returns a dict keyed by feature name, with
		# 'xyz' of shape (B, N, 3) and the grouped features of shape
		# (B, N, n_sample, C) -- confirm against utils.sample_and_group_multi.
		features = sample_and_group_multi(-1, self.radius, self.n_sample, xyz, normals)
		# Insert a singleton neighbour axis so 'xyz' can be expanded below.
		features['xyz'] = features['xyz'][:, :, None, :]

		# Gate and concat: expand every selected feature to n_sample along
		# axis 2, then stack them along the channel (last) axis.
		concat = [features[f].expand(-1, -1, self.n_sample, -1) for f in self.features]
		fused_input_feat = torch.cat(concat, -1)

		# Prepool_FC, pool, postpool-FC
		new_feat = fused_input_feat.permute(0, 3, 2, 1)  # [B, raw_dim, n_sample, N]
		new_feat = self.prepool(new_feat)

		pooled_feat = torch.max(new_feat, 2)[0]  # Max pooling (B, C, N)

		post_feat = self.postpool(pooled_feat)  # Post pooling dense layers
		cluster_feat = post_feat.permute(0, 2, 1)
		# F.normalize clamps the norm by a small eps, so a zero vector yields
		# zeros; the plain division it replaces produced NaN in that case.
		cluster_feat = F.normalize(cluster_feat, p=2, dim=-1)

		return cluster_feat  # (B, N, C)