File size: 12,819 Bytes
97aa5af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from thirdparty.learning3d.utils import square_distance, angle_difference
from thirdparty.learning3d.ops.transform_functions import convert2transformation

_EPS = 1e-5  # To prevent division by zero

class ParameterPredictionNet(nn.Module):
	"""PointNet-style network predicting the matching parameters (beta, alpha) for a cloud pair."""

	def __init__(self, weights_dim):
		"""PointNet based Parameter prediction network

		Args:
			weights_dim: Shape of the extra weights to predict (excluding beta and alpha), e.g.
						 [3] or [64, 3]. The last layer has 2 + prod(weights_dim) outputs, but
						 forward() currently only returns the first two (beta, alpha).
		"""

		super().__init__()

		self._logger = logging.getLogger(self.__class__.__name__)

		self.weights_dim = weights_dim

		# Pointnet: shared per-point MLP over 4 input channels (xyz + cloud indicator)
		self.prepool = nn.Sequential(
			nn.Conv1d(4, 64, 1),
			nn.GroupNorm(8, 64),
			nn.ReLU(),

			nn.Conv1d(64, 64, 1),
			nn.GroupNorm(8, 64),
			nn.ReLU(),

			nn.Conv1d(64, 64, 1),
			nn.GroupNorm(8, 64),
			nn.ReLU(),

			nn.Conv1d(64, 128, 1),
			nn.GroupNorm(8, 128),
			nn.ReLU(),

			nn.Conv1d(128, 1024, 1),
			nn.GroupNorm(16, 1024),
			nn.ReLU(),
		)
		self.pooling = nn.AdaptiveMaxPool1d(1)
		self.postpool = nn.Sequential(
			nn.Linear(1024, 512),
			nn.GroupNorm(16, 512),
			nn.ReLU(),

			nn.Linear(512, 256),
			nn.GroupNorm(16, 256),
			nn.ReLU(),

			# int(): np.prod yields a numpy scalar, and the float 1.0 for an empty
			# weights_dim, which nn.Linear rejects as a size.
			nn.Linear(256, 2 + int(np.prod(weights_dim))),
		)

		self._logger.info('Predicting weights with dim {}.'.format(self.weights_dim))

	def forward(self, x):
		"""Predict beta and alpha for a pair of point clouds.

		Args:
			x: List containing two point clouds, x[0] = src (B, J, 3), x[1] = ref (B, K, 3)

		Returns:
			(beta, alpha): two (B,) tensors, each made strictly positive by softplus.
		"""
		# Append a 4th channel flagging which cloud a point comes from (0 = src, 1 = ref),
		# then process the union of both clouds as one set.
		src_padded = F.pad(x[0], (0, 1), mode='constant', value=0)
		ref_padded = F.pad(x[1], (0, 1), mode='constant', value=1)
		concatenated = torch.cat([src_padded, ref_padded], dim=1)

		prepool_feat = self.prepool(concatenated.permute(0, 2, 1))  # (B, 1024, J + K)
		pooled = torch.flatten(self.pooling(prepool_feat), start_dim=-2)  # (B, 1024)
		raw_weights = self.postpool(pooled)

		# softplus to ensure positivity
		beta = F.softplus(raw_weights[:, 0])
		alpha = F.softplus(raw_weights[:, 1])

		return beta, alpha



def to_numpy(tensor):
	"""Return ``tensor`` as a numpy array.

	A numpy array is passed through unchanged (same object). A torch tensor is detached
	from the autograd graph and copied to host memory first. Anything else raises
	NotImplementedError.
	"""
	if isinstance(tensor, np.ndarray):
		return tensor
	if not isinstance(tensor, torch.Tensor):
		raise NotImplementedError
	return tensor.detach().cpu().numpy()


def se3_transform(g, a, normals=None):
	"""Apply an SE(3) transform to points and, optionally, normals.

	Args:
		g: Transformation matrix (..., 3 or 4, 4); only the top [R | p] 3x4 part is used.
		   Must have the same number of dimensions as ``a``, e.g. (3/4, 4) with (N, 3)
		   points or (B, 3/4, 4) with (B, N, 3) points.
		a: Points to be transformed (..., N, 3).
		normals: (Optional) normals (..., N, 3). They are rotated but not translated.

	Returns:
		Transformed points (..., N, 3), or the tuple ``(points, rotated_normals)`` when
		``normals`` is given.

	Raises:
		NotImplementedError: if ``g`` and ``a`` differ in number of dimensions.
	"""
	if g.dim() != a.dim():
		# The mixed-rank case was never implemented (the old fallback after the raise was
		# unreachable), so fail loudly rather than guess a broadcast.
		raise NotImplementedError(
			'se3_transform expects g and a with the same number of dimensions, '
			'got {} and {}'.format(g.dim(), a.dim()))

	R = g[..., :3, :3]  # (..., 3, 3)
	p = g[..., :3, 3]  # (..., 3)

	# Row-vector convention: x' = x R^T + p
	b = torch.matmul(a, R.transpose(-1, -2)) + p[..., None, :]

	if normals is None:
		return b
	rotated_normals = normals @ R.transpose(-1, -2)
	return b, rotated_normals

def match_features(feat_src, feat_ref, metric='l2'):
	"""Pairwise distance matrix between two batched feature sets.

	When the two inputs have different feature widths, the wider one is cut down to
	the leading channels of the narrower one before they are compared.

	Args:
		feat_src: (B, J, C) source features
		feat_ref: (B, K, C) reference features
		metric: 'l2' (squared euclidean, via square_distance) or 'angle'
			(angle_difference between L2-normalised features)

	Returns:
		Matching matrix (B, J, K); row j describes how well source point j agrees
		with every point in the reference.

	Raises:
		NotImplementedError: for any other metric.
	"""
	width_src, width_ref = feat_src.shape[-1], feat_ref.shape[-1]
	common_width = min(width_src, width_ref)
	if width_src > common_width:
		feat_src = feat_src[:, :, :common_width]
	if width_ref > common_width:
		feat_ref = feat_ref[:, :, :common_width]

	assert feat_src.shape[-1] == feat_ref.shape[-1]

	if metric == 'l2':
		return square_distance(feat_src, feat_ref)
	if metric == 'angle':
		unit_src = feat_src / (torch.norm(feat_src, dim=-1, keepdim=True) + _EPS)
		unit_ref = feat_ref / (torch.norm(feat_ref, dim=-1, keepdim=True) + _EPS)
		return angle_difference(unit_src, unit_ref)
	raise NotImplementedError


def sinkhorn(log_alpha, n_iters: int = 5, slack: bool = True, eps: float = -1) -> torch.Tensor:
	"""Sinkhorn-normalise a matrix in log space.

	Alternately normalises the rows and the columns of exp(log_alpha) so the result is close
	to doubly stochastic. With ``slack``, a zero-valued extra row and column are appended. The
	slack row is skipped by the row step and the slack column by the column step, and both are
	stripped from the output, so the remaining rows/columns may sum to less than 1.

	Args:
		log_alpha: log of positive matrix to apply sinkhorn normalization (B, J, K)
		n_iters (int): Maximum number of row+column normalisation rounds
		slack (bool): Whether to include slack row and column
		eps: Early-termination tolerance (used only for handcrafted RPM). Stops once the largest
			per-batch sum of absolute changes of the normalised (B, J, K) matrix between two
			consecutive rounds is below ``eps``. Set to a value <= 0 to disable.

	Returns:
		log of the normalised matrix, (B, J, K)

	Modified from original source taken from:
		Learning Latent Permutations with Gumbel-Sinkhorn Networks
		https://github.com/HeddaCohenIndelman/Learning-Gumbel-Sinkhorn-Permutations-w-Pytorch
	"""
	track_change = eps > 0
	last_probs = None

	if not slack:
		for _ in range(n_iters):
			log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=2, keepdim=True)  # rows -> 1
			log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=1, keepdim=True)  # columns -> 1
			if track_change:
				probs = torch.exp(log_alpha)
				if last_probs is not None and torch.max(torch.sum(torch.abs(probs - last_probs), dim=[1, 2])) < eps:
					break
				last_probs = probs.clone()
		return log_alpha

	# Append a slack row and column filled with log-value 0: (B, J + 1, K + 1)
	padded = nn.ZeroPad2d((0, 1, 0, 1))(log_alpha[:, None, :, :]).squeeze(1)
	for _ in range(n_iters):
		# Normalise every row except the slack row
		rows = padded[:, :-1, :]
		padded = torch.cat((rows - torch.logsumexp(rows, dim=2, keepdim=True), padded[:, -1:, :]), dim=1)

		# Normalise every column except the slack column
		cols = padded[:, :, :-1]
		padded = torch.cat((cols - torch.logsumexp(cols, dim=1, keepdim=True), padded[:, :, -1:]), dim=2)

		if track_change:
			probs = torch.exp(padded[:, :-1, :-1])
			if last_probs is not None and torch.max(torch.sum(torch.abs(probs - last_probs), dim=[1, 2])) < eps:
				break
			last_probs = probs.clone()

	return padded[:, :-1, :-1]


def compute_rigid_transform(a: torch.Tensor, b: torch.Tensor, weights: torch.Tensor):
	"""Weighted least-squares rigid transform from ``a`` to ``b`` (weighted Kabsch).

	Args:
		a (torch.Tensor): (B, M, 3) source points
		b (torch.Tensor): (B, M, 3) target points; must have the same number of points as
			``a`` because both are multiplied element-wise by the per-point weights
		weights (torch.Tensor): (B, M) per-point weights, normalised internally to sum to ~1
			per batch (``_EPS`` guards against an all-zero weight vector)

	Returns:
		Transform T (B, 3, 4) = [R | t] to get from a to b, i.e. T*a = b
	"""
	w = weights[..., None]
	w = w / (torch.sum(w, dim=1, keepdim=True) + _EPS)  # (B, M, 1)

	mean_a = torch.sum(a * w, dim=1)  # (B, 3)
	mean_b = torch.sum(b * w, dim=1)  # (B, 3)

	# Weighted cross-covariance of the centred point sets, (B, 3, 3)
	cov = (a - mean_a[:, None, :]).transpose(-2, -1) @ ((b - mean_b[:, None, :]) * w)

	# Kabsch: R = V U^T. When that candidate is a reflection (det <= 0), use the variant with
	# the last column of V negated instead, so the result is a proper rotation.
	u, _, v = torch.svd(cov, some=False, compute_uv=True)
	u_t = u.transpose(-1, -2)
	rot_candidate = v @ u_t
	v_reflected = v.clone()
	v_reflected[:, :, 2] *= -1
	rot_corrected = v_reflected @ u_t
	is_proper = torch.det(rot_candidate)[:, None, None] > 0
	rotation = torch.where(is_proper, rot_candidate, rot_corrected)
	assert torch.all(torch.det(rotation) > 0)

	# Translation that maps the weighted centroid of a onto that of b: t = c_b - R c_a
	translation = -rotation @ mean_a[:, :, None] + mean_b[:, :, None]  # (B, 3, 1)

	return torch.cat((rotation, translation), dim=2)

class R3PMNet(nn.Module):
	"""Iterative soft-correspondence rigid registration network.

	Each iteration predicts (beta, alpha) with ParameterPredictionNet, extracts per-point
	features with the supplied feature model, turns feature distances into a soft assignment
	with Sinkhorn normalisation, and solves for the rigid transform with a weighted Kabsch step.
	"""

	def __init__(self, feature_model):
		"""
		Args:
			feature_model: Module/callable producing per-point features. It is first called as
				``feature_model(xyz)``; if that raises, it is called as
				``feature_model(xyz, normals)``. Presumably returns (B, N', C) features with
				N' <= N -- confirm against the extractors used with this class.
		"""

		super().__init__()

		self.add_slack = True  # Sinkhorn with slack row/column (rows/columns may sum to < 1)
		self.num_sk_iter = 5   # number of Sinkhorn normalisation rounds

		self.weights_net = ParameterPredictionNet(weights_dim=[0])
		self.feat_extractor = feature_model

	def compute_affinity(self, beta, feat_distance, alpha=0.5):
		"""Compute logarithm of Initial match matrix values, i.e. log(m_jk) = -beta * (d_jk - alpha).

		Args:
			beta: (B,) tensor.
			feat_distance: (B, J, K) feature distances.
			alpha: (B,) tensor (one value per batch element), or a scalar (python number or
				0-d tensor) shared by the whole batch.

		Returns:
			(B, J, K) log-affinities.
		"""
		if isinstance(alpha, torch.Tensor) and alpha.dim() > 0:
			alpha = alpha[:, None, None]
		return -beta[:, None, None] * (feat_distance - alpha)

	@staticmethod
	def split_normals(data):
		"""Split a (B, N, 3) or (B, N, 6) cloud into xyz and normals.

		With only 3 channels the normals are zeros with the same shape, dtype and device
		as ``data``.

		Raises:
			ValueError: if the last dimension is neither 3 nor 6.
		"""
		channels = data.shape[2]
		if channels == 6:
			return data[:, :, :3], data[:, :, 3:6]
		if channels == 3:
			return data, torch.zeros_like(data)
		raise ValueError(
			'Expected point clouds with 3 (xyz) or 6 (xyz + normals) channels, got {}'.format(channels))

	def spam(self, xyz_template, norm_template, xyz_source, norm_source):
		"""Soft-match the source to the template and return perm-weighted template points.

		Stores ``beta``, ``alpha``, ``feat_source``, ``feat_template``, ``affinity`` and
		``perm_matrix`` on ``self``; forward() reads them back.

		Returns:
			(B, J', 3) tensor where J' = perm_matrix.shape[1]: for each source row, the
			perm-weighted average of the template points.
		"""
		self.beta, self.alpha = self.weights_net([xyz_source, xyz_template])

		try:  # extractors that take xyz only (e.g. R3PMNet features)
			self.feat_source = self.feat_extractor(xyz_source)
			self.feat_template = self.feat_extractor(xyz_template)
		except Exception:
			# Fall back to extractors that also need normals. ``Exception`` instead of a bare
			# ``except`` so KeyboardInterrupt / SystemExit are not swallowed.
			self.feat_source = self.feat_extractor(xyz_source, norm_source)
			self.feat_template = self.feat_extractor(xyz_template, norm_template)

		feat_distance = match_features(self.feat_source, self.feat_template)
		self.affinity = self.compute_affinity(self.beta, feat_distance, alpha=self.alpha)

		# Soft assignment (B, J', K') between source and template feature points
		log_perm_matrix = sinkhorn(self.affinity, n_iters=self.num_sk_iter, slack=self.add_slack)
		self.perm_matrix = torch.exp(log_perm_matrix)

		# Keep as many template points as the perm matrix has COLUMNS (template feature
		# points), so the product is well-defined even if the extractor returned fewer
		# points than it was given.
		n_template = self.perm_matrix.shape[2]
		row_mass = torch.sum(self.perm_matrix, dim=2, keepdim=True)
		weighted_template = self.perm_matrix @ xyz_template[:, :n_template] / (row_mass + _EPS)
		return weighted_template

	def forward(self, template, source, max_iterations: int = 1):
		"""Forward pass for R3PM-Net: estimate the transform aligning ``source`` to ``template``.

		Args:
			template: Template (reference) points (B, K, 3), or (B, K, 6) with normals in the
				last three channels.
			source: Source points (B, J, 3) or (B, J, 6).
			max_iterations (int): Number of correspondence/transform refinement iterations (>= 1).

		Returns:
			dict with
				'est_R' (rotation), 'est_t' (translation), 'est_T' (transform) -- source -> template;
				'r': feat_template - feat_source (only present when the two broadcast);
				'transformed_source': ``source`` xyz with the final transform applied;
				'perm_matrices_init', 'perm_matrices', 'weighted_template', 'transforms':
					per-iteration lists;
				'beta', 'alpha': numpy arrays stacked over iterations.

		Raises:
			ValueError: if ``max_iterations`` < 1.
		"""
		if max_iterations < 1:
			raise ValueError('max_iterations must be >= 1, got {}'.format(max_iterations))

		xyz_template, norm_template = self.split_normals(template)
		xyz_source, norm_source = self.split_normals(source)

		# Re-posed copy of the source used to search correspondences. Each transform is still
		# estimated from the ORIGINAL source, so every entry of ``transforms`` maps the
		# untouched source onto the template.
		xyz_source_t, norm_source_t = xyz_source, norm_source

		transforms = []
		all_gamma, all_perm_matrices, all_weighted_template = [], [], []
		all_beta, all_alpha = [], []

		for _ in range(max_iterations):
			# Finding better correspondences after each iteration.
			weighted_template = self.spam(xyz_template, norm_template, xyz_source_t, norm_source_t)

			# Only the first J' source points have a row in the perm matrix; keep xyz and
			# normals consistent with that.
			n_corr = weighted_template.shape[1]
			xyz_source_corr = xyz_source[:, :n_corr]
			norm_source_corr = norm_source[:, :n_corr]

			# Compute transform and apply it to the original source.
			transform = compute_rigid_transform(
				xyz_source_corr, weighted_template, weights=torch.sum(self.perm_matrix, dim=2))
			xyz_source_t, norm_source_t = se3_transform(transform.detach(), xyz_source_corr, norm_source_corr)

			transforms.append(transform)
			all_gamma.append(torch.exp(self.affinity))
			all_perm_matrices.append(self.perm_matrix)
			all_weighted_template.append(weighted_template)
			all_beta.append(to_numpy(self.beta))
			all_alpha.append(to_numpy(self.alpha))

		final_transform = transforms[-1]
		est_T = convert2transformation(final_transform[:, :3, :3], final_transform[:, :3, 3])
		transformed_source = torch.bmm(est_T[:, :3, :3], source[:, :, :3].permute(0, 2, 1)).permute(0, 2, 1) + est_T[:, :3, 3].unsqueeze(1)

		result = {
			'est_R': est_T[:, :3, :3],	# source -> template
			'est_t': est_T[:, :3, 3],	# source -> template
			'est_T': est_T,				# source -> template
		}
		try:  # for training: feature residual, only defined when the feature tensors broadcast
			result['r'] = self.feat_template - self.feat_source
		except RuntimeError:
			pass  # deliberate best-effort: shapes differ, so no residual is reported
		result['transformed_source'] = transformed_source

		result['perm_matrices_init'] = all_gamma
		result['perm_matrices'] = all_perm_matrices
		result['weighted_template'] = all_weighted_template
		result['beta'] = np.stack(all_beta, axis=0)
		result['alpha'] = np.stack(all_alpha, axis=0)
		result['transforms'] = transforms

		return result


if __name__ == '__main__':
	# Smoke test on random data. R3PMNet requires a feature extractor; a per-point linear
	# layer (xyz -> 32 features) stands in for a real one here.
	template, source = torch.rand(10, 1024, 6), torch.rand(10, 1024, 6)

	net = R3PMNet(feature_model=nn.Linear(3, 32))
	with torch.no_grad():
		result = net(template, source)
	for key, value in result.items():
		if isinstance(value, torch.Tensor):
			print(key, tuple(value.shape))