# R3PM-Net/thirdparty/learning3d/ops/data_utils.py
# YasiiKB's picture
# initial commit
# 97aa5af verified
import torch
def mean_shift(template, source, p0_zero_mean, p1_zero_mean):
    """Optionally centre each point cloud and build the matching shift transforms.

    The returned matrices are the ones ``postprocess_data`` multiplies around
    the estimated transform: ``trans(p0_m) * est_T * trans(-p1_m)``.

    Args:
        template: Template point cloud, ``[B, N, 3]``.
        source: Source point cloud, ``[B, N, 3]``.
        p0_zero_mean: When true, subtract the per-cloud mean from ``template``.
        p1_zero_mean: When true, subtract the per-cloud mean from ``source``.

    Returns:
        Tuple ``(template, source, template_mean, source_mean)``.

        * ``template`` / ``source`` are centred only if the matching flag is set.
        * ``template_mean`` is a ``[B, 4, 4]`` homogeneous translation by
          ``+p0_m`` when ``p0_zero_mean`` is true, otherwise the ``[B, 3, 3]``
          identity it was initialised with.
        * ``source_mean`` is a ``[B, 4, 4]`` homogeneous translation by
          ``-p1_m`` when ``p1_zero_mean`` is true, otherwise the ``[B, 3, 3]``
          identity it was initialised with.
    """
    # Start from batched 3x3 identities on the same device/dtype as the inputs.
    template_mean = torch.eye(3).view(1, 3, 3).expand(template.size(0), 3, 3).to(template)  # [B, 3, 3]
    source_mean = torch.eye(3).view(1, 3, 3).expand(source.size(0), 3, 3).to(source)  # [B, 3, 3]

    if p0_zero_mean:
        p0_m = template.mean(dim=1)  # [B, N, 3] -> [B, 3]
        # Append the translation column, then the homogeneous bottom row.
        template_mean = torch.cat([template_mean, p0_m.unsqueeze(-1)], dim=2)  # [B, 3, 4]
        bottom_row = torch.tensor([[[0.0, 0.0, 0.0, 1.0]]]).repeat(
            template_mean.shape[0], 1, 1
        ).to(template_mean)  # [B, 1, 4]
        template_mean = torch.cat([template_mean, bottom_row], dim=1)  # [B, 4, 4]
        template = template - p0_m.unsqueeze(1)

    if p1_zero_mean:
        p1_m = source.mean(dim=1)  # [B, N, 3] -> [B, 3]
        # The source-side matrix must carry the *source* mean (negated), not
        # the template mean: est_T is later right-multiplied by trans(-p1_m).
        source_mean = torch.cat([source_mean, -p1_m.unsqueeze(-1)], dim=2)  # [B, 3, 4]
        bottom_row = torch.tensor([[[0.0, 0.0, 0.0, 1.0]]]).repeat(
            source_mean.shape[0], 1, 1
        ).to(source_mean)  # [B, 1, 4]
        source_mean = torch.cat([source_mean, bottom_row], dim=1)  # [B, 4, 4]
        source = source - p1_m.unsqueeze(1)

    return template, source, template_mean, source_mean
def postprocess_data(result, p0, p1, a0, a1, p0_zero_mean, p1_zero_mean):
    """Undo the zero-mean shift on the estimated transforms stored in ``result``.

    For each estimate the corrected transform is::

        trans(p0_m) * estimate * trans(-p1_m)
        = [I, p0_m] * [R, t] * [I, -p1_m]
          [0,   1 ]   [0, 1]   [0,    1 ]

    where the left factor is applied only if ``p0_zero_mean`` is set and the
    right factor only if ``p1_zero_mean`` is set.

    Args:
        result: Dict holding ``'est_T'`` and ``'est_T_series'``
            (the latter is ``[M, B, 4, 4]``); both entries are overwritten.
        p0: Unused here; kept for interface compatibility.
        p1: Unused here; kept for interface compatibility.
        a0: Batched matrix multiplied on the left when ``p0_zero_mean`` is set.
        a1: Batched matrix multiplied on the right when ``p1_zero_mean`` is set.
        p0_zero_mean: Whether the template side was mean-shifted.
        p1_zero_mean: Whether the source side was mean-shifted.

    Returns:
        The same ``result`` dict, updated in place.
    """
    # Final estimate: batched matrix product, one transform per batch item.
    final_pose = result['est_T']
    if p0_zero_mean:
        left = a0.to(final_pose)
        final_pose = left.bmm(final_pose)
    if p1_zero_mean:
        right = a1.to(final_pose)
        final_pose = final_pose.bmm(right)
    result['est_T'] = final_pose

    # Per-iteration estimates: add a leading axis to a0/a1 so matmul
    # broadcasts them across the M entries of the series.
    pose_series = result['est_T_series']  # [M, B, 4, 4]
    if p0_zero_mean:
        left_series = a0.unsqueeze(0).contiguous().to(pose_series)
        pose_series = left_series.matmul(pose_series)
    if p1_zero_mean:
        right_series = a1.unsqueeze(0).contiguous().to(pose_series)
        pose_series = pose_series.matmul(right_series)
    result['est_T_series'] = pose_series

    return result