File size: 1,764 Bytes
97aa5af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch

def _append_translation(rotation, translation):
	"""Return the batched homogeneous matrix ``[rotation, translation; 0, 0, 0, 1]``.

	Args:
		rotation: [B, 3, 3] tensor.
		translation: [B, 3] tensor.

	Returns:
		[B, 4, 4] tensor with the same dtype/device as ``rotation``.
	"""
	top = torch.cat([rotation, translation.unsqueeze(-1)], dim=2)  # [B, 3, 4]
	# new_tensor keeps dtype and device of `top`, matching the original `.to(...)` cast.
	bottom = top.new_tensor([0.0, 0.0, 0.0, 1.0]).view(1, 1, 4).expand(top.shape[0], 1, 4)  # [B, 1, 4]
	return torch.cat([top, bottom], dim=1)


def mean_shift(template, source, p0_zero_mean, p1_zero_mean):
	"""Optionally centre two point clouds and build the matching homogeneous transforms.

	Args:
		template: point-cloud tensor; its mean is taken over dim=1 (laid out as [B, N, 3]).
		source: point-cloud tensor with the same layout as ``template``.
		p0_zero_mean: if True, subtract the per-batch mean ``p0_m`` from ``template``.
		p1_zero_mean: if True, subtract the per-batch mean ``p1_m`` from ``source``.

	Returns:
		Tuple ``(template, source, template_mean, source_mean)``.
		- ``template`` and ``source`` are the (possibly centred) clouds.
		- ``template_mean`` is the [B, 4, 4] matrix ``[I, p0_m; 0, 1]`` when the template
		  was centred. ``source_mean`` is ``[I, -p1_m; 0, 1]`` when the source was centred.
		- These are the left and right correction factors applied in ``postprocess_data``.
		- For a cloud that was not centred, the matrix is the [B, 3, 3] identity.
	"""
	template_mean = torch.eye(3).view(1, 3, 3).expand(template.size(0), 3, 3).to(template)  # [B, 3, 3]
	source_mean = torch.eye(3).view(1, 3, 3).expand(source.size(0), 3, 3).to(source)  # [B, 3, 3]

	if p0_zero_mean:
		p0_m = template.mean(dim=1)  # [B, N, 3] -> [B, 3]
		template_mean = _append_translation(template_mean, p0_m)
		template = template - p0_m.unsqueeze(1)

	if p1_zero_mean:
		p1_m = source.mean(dim=1)  # [B, N, 3] -> [B, 3]
		# Use the *source* mean here. Using p0_m was wrong, and it also failed with
		# NameError when only the source was being centred.
		source_mean = _append_translation(source_mean, -p1_m)
		source = source - p1_m.unsqueeze(1)

	return template, source, template_mean, source_mean

def postprocess_data(result, p0, p1, a0, a1, p0_zero_mean, p1_zero_mean):
	"""Re-apply the centring offsets to the estimated transforms in ``result``.

	The corrected transform is

		output' = trans(p0_m) * output * trans(-p1_m)
		        = [I, p0_m; 0, 1] * [R, t; 0, 1] * [I, -p1_m; 0, 1]

	where ``a0`` is the left factor (used only if ``p0_zero_mean``) and ``a1`` is
	the right factor (used only if ``p1_zero_mean``).

	Args:
		result: dict with 'est_T' (batched matrices, combined via ``bmm``) and
			'est_T_series' (commented as [M, B, 4, 4]); both entries are replaced.
		p0, p1: not used by this function.
		a0, a1: correction matrices, cast to the dtype/device of each estimate.
		p0_zero_mean, p1_zero_mean: select which correction factors are applied.

	Returns:
		The same ``result`` dict, updated in place.
	"""
	final = result['est_T']
	if p0_zero_mean:
		final = torch.bmm(a0.to(final), final)
	if p1_zero_mean:
		final = torch.bmm(final, a1.to(final))
	result['est_T'] = final

	# The series carries an extra leading [M] axis, so broadcast the [B, 4, 4]
	# factors over it with a singleton dimension.
	series = result['est_T_series']  # [M, B, 4, 4]
	if p0_zero_mean:
		left = a0[None].contiguous().to(series)
		series = torch.matmul(left, series)
	if p1_zero_mean:
		right = a1[None].contiguous().to(series)
		series = torch.matmul(series, right)
	result['est_T_series'] = series

	return result