| import torch |
|
|
def _to_homogeneous(linear, translation):
    """Assemble batched 4x4 matrices ``[[linear, translation], [0, 0, 0, 1]]``.

    ``linear`` is (B, 3, 3) and ``translation`` is (B, 3); the result is
    (B, 4, 4) with the dtype/device of ``linear``.
    """
    top = torch.cat([linear, translation.unsqueeze(-1)], dim=2)  # (B, 3, 4)
    bottom = torch.tensor([[[0.0, 0.0, 0.0, 1.0]]]).repeat(top.shape[0], 1, 1).to(top)  # (B, 1, 4)
    return torch.cat([top, bottom], dim=1)


def mean_shift(template, source, p0_zero_mean, p1_zero_mean):
    """Optionally centre two batched point clouds on their own centroids.

    Args:
        template: Batched points; the mean is taken over dim 1 and the result
            is used as a 3-vector translation, so centring needs (B, N, 3).
        source: Batched points, same layout requirements as ``template``.
        p0_zero_mean: If true, subtract the per-batch centroid from ``template``.
        p1_zero_mean: If true, subtract the per-batch centroid from ``source``.

    Returns:
        A tuple ``(template, source, template_mean, source_mean)``:

        * ``template`` / ``source`` are the (possibly centred) point clouds.
        * ``template_mean`` is a (B, 4, 4) matrix ``[I | +centroid]`` when
          ``p0_zero_mean`` is true, i.e. it maps centred template coordinates
          back to the original ones.
        * ``source_mean`` is a (B, 4, 4) matrix ``[I | -centroid]`` when
          ``p1_zero_mean`` is true, i.e. it maps original source coordinates
          to the centred ones.

        When a flag is false the corresponding matrix is left as a (B, 3, 3)
        identity (kept for backward compatibility with existing callers).
    """
    template_mean = torch.eye(3).view(1, 3, 3).expand(template.size(0), 3, 3).to(template)
    source_mean = torch.eye(3).view(1, 3, 3).expand(source.size(0), 3, 3).to(source)

    if p0_zero_mean:
        p0_m = template.mean(dim=1)  # per-batch centroid
        template_mean = _to_homogeneous(template_mean, p0_m)
        template = template - p0_m.unsqueeze(1)

    if p1_zero_mean:
        p1_m = source.mean(dim=1)  # per-batch centroid
        # Must use the *source* centroid here: the matrix has to describe the
        # same shift that is applied to ``source`` on the next line. Using the
        # template centroid gave a wrong transform and failed outright when
        # only the source was centred.
        source_mean = _to_homogeneous(source_mean, -p1_m)
        source = source - p1_m.unsqueeze(1)

    return template, source, template_mean, source_mean
|
|
def postprocess_data(result, p0, p1, a0, a1, p0_zero_mean, p1_zero_mean):
    """Compose the supplied transforms into the estimated poses stored in ``result``.

    ``result['est_T']`` and ``result['est_T_series']`` are replaced in place:
    each is left-multiplied by ``a0`` when ``p0_zero_mean`` is true and
    right-multiplied by ``a1`` when ``p1_zero_mean`` is true.

    Args:
        result: Mapping holding the tensors ``'est_T'`` and ``'est_T_series'``.
        p0: Accepted for interface compatibility; not used here.
        p1: Accepted for interface compatibility; not used here.
        a0: Batched matrices applied on the left (cast to the estimate's
            dtype/device before multiplying).
        a1: Batched matrices applied on the right (same casting).
        p0_zero_mean: Whether to apply ``a0``.
        p1_zero_mean: Whether to apply ``a1``.

    Returns:
        The same ``result`` object, with both entries updated.
    """
    # Final estimate: strict batched matrix product (3-D tensors).
    final_pose = result['est_T']
    if p0_zero_mean:
        final_pose = torch.bmm(a0.to(final_pose), final_pose)
    if p1_zero_mean:
        final_pose = torch.bmm(final_pose, a1.to(final_pose))
    result['est_T'] = final_pose

    # Series of estimates: a singleton leading axis is added to a0/a1 so the
    # product broadcasts over the leading axis of the series
    # (presumably one entry per iteration -- confirm against the producer).
    pose_series = result['est_T_series']
    if p0_zero_mean:
        left = a0.unsqueeze(0).contiguous().to(pose_series)
        pose_series = torch.matmul(left, pose_series)
    if p1_zero_mean:
        right = a1.unsqueeze(0).contiguous().to(pose_series)
        pose_series = torch.matmul(pose_series, right)
    result['est_T_series'] = pose_series

    return result
|
|