File size: 5,805 Bytes
97aa5af | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 | import torch
import torch.nn as nn
import torch.nn.functional as F
from .pointnet import PointNet
from .pooling import Pooling
from .. ops import data_utils
from .. ops import se3, so3, invmat
class PointNetLK(nn.Module):
    """Point-cloud registration of ``source`` onto ``template`` with pooled
    point features and a Lucas-Kanade style iteration (see ``iclk``).

    The Jacobian of the pooled template features w.r.t. a 6-D twist is
    approximated once per call by finite differences (``approx_Jic``). Each
    iteration then applies the twist update ``-pinv(J) @ r``, where ``r`` is
    the feature residual between the transformed source and the template.
    """

    def __init__(self, feature_model=None, delta=1.0e-2, learn_delta=False, xtol=1.0e-7,
                 p0_zero_mean=True, p1_zero_mean=True, pooling='max'):
        """Build the registration module.

        Args:
            feature_model: module applied to point clouds before pooling.
                If None, a new ``PointNet()`` is created for this instance.
            delta: finite-difference step used for all 6 twist components.
            learn_delta: whether that step is a trainable parameter.
            xtol: iteration stops once the largest per-sample twist-update
                L2 norm falls below this value.
            p0_zero_mean, p1_zero_mean: flags forwarded to
                ``data_utils.mean_shift`` / ``postprocess_data`` (presumably
                whether to centre template / source; confirm in data_utils).
            pooling: pooling mode name passed to ``Pooling``.
        """
        super().__init__()
        # A ``PointNet()`` default in the signature would be evaluated once at
        # definition time and shared (weights included) by every instance that
        # omits ``feature_model``; create a fresh one per instance instead.
        if feature_model is None:
            feature_model = PointNet()
        self.feature_model = feature_model
        self.pooling = Pooling(pooling)
        self.inverse = invmat.InvMatrix.apply
        self.exp = se3.Exp  # [B, 6] -> [B, 4, 4]
        self.transform = se3.transform  # [B, 1, 4, 4] x [B, N, 3] -> [B, N, 3]

        # Finite-difference step, identical for (w1, w2, w3, v1, v2, v3).
        twist = torch.Tensor([delta] * 6)
        self.dt = torch.nn.Parameter(twist.view(1, 6), requires_grad=learn_delta)

        # results
        self.last_err = None
        self.g_series = None  # for debug purpose
        self.prev_r = None
        self.g = None  # legacy attribute; iclk stores its estimate in est_T
        self.est_T = None  # last [B, 4, 4] estimate, set by iclk
        self.est_T_series = None  # per-iteration estimates, set by iclk
        self.itr = 0
        self.xtol = xtol
        self.p0_zero_mean = p0_zero_mean
        self.p1_zero_mean = p1_zero_mean

    def forward(self, template, source, maxiter=10):
        """Register ``source`` to ``template``; returns the dict built by ``iclk``.

        Inputs go through ``data_utils.mean_shift`` first, and the result is
        passed with the returned means to ``data_utils.postprocess_data``
        (presumably to undo the centring; confirm in data_utils).
        """
        template, source, template_mean, source_mean = data_utils.mean_shift(
            template, source, self.p0_zero_mean, self.p1_zero_mean)
        result = self.iclk(template, source, maxiter)
        return data_utils.postprocess_data(result, template, source, template_mean, source_mean,
                                           self.p0_zero_mean, self.p1_zero_mean)

    def _make_result(self, est_T, r, source, itr):
        """Pack a [B, 4, 4] estimate and bookkeeping into the ``iclk`` result dict."""
        return {'est_R': est_T[:, 0:3, 0:3],
                'est_t': est_T[:, 0:3, 3],
                'est_T': est_T,
                'r': r,
                'transformed_source': self.transform(est_T.unsqueeze(1), source),
                'itr': itr,
                'est_T_series': self.est_T_series}

    def iclk(self, template, source, maxiter):
        """Iteratively estimate the transform aligning ``source`` to ``template``.

        Args:
            template, source: point clouds, [B, N, 3].
            maxiter: maximum number of pose updates.

        Returns:
            dict with 'est_R', 'est_t', 'est_T' ([B, 4, 4]), 'r' (last feature
            residual, or None), 'transformed_source', 'itr' and 'est_T_series'
            ([maxiter + 1, B, 4, 4]; slots after the last update repeat the
            final estimate).

        If ``feature_model`` was in training mode on entry, that mode is
        restored on every exit path, including the singular-Jacobian one.
        """
        batch_size = template.size(0)
        est_T = torch.eye(4).to(template).view(1, 4, 4).expand(batch_size, 4, 4).contiguous()
        # History of estimates; slot 0 is the identity initialisation.
        # Allocated on the default device with est_T's dtype, as before.
        self.est_T_series = torch.zeros(maxiter + 1, *est_T.size(), dtype=est_T.dtype)
        self.est_T_series[0] = est_T

        training = self.handle_batchNorm(template, source)
        try:
            # re-calc. with current modules
            template_features = self.pooling(self.feature_model(template))  # [B, N, 3] -> [B, K]

            # approx. J by finite difference
            dt = self.dt.to(template).expand(batch_size, 6)
            J = self.approx_Jic(template, template_features, dt)

            self.last_err = None
            pinv = self.compute_inverse_jacobian(J, template_features, source)
            if not torch.is_tensor(pinv):
                # J^T J could not be inverted (error kept in self.last_err):
                # report the identity estimate for every history slot.
                self.est_T_series[1:] = est_T
                self.est_T = est_T
                return self._make_result(est_T, None, source, 1)

            itr = 0
            r = None
            for itr in range(maxiter):
                self.prev_r = r
                transformed_source = self.transform(est_T.unsqueeze(1), source)  # [B, N, 3]
                source_features = self.pooling(self.feature_model(transformed_source))  # [B, K]
                r = source_features - template_features
                pose = -pinv.bmm(r.unsqueeze(-1)).view(batch_size, 6)
                if float(pose.norm(p=2, dim=1, keepdim=True).max()) < self.xtol:
                    if itr == 0:
                        self.last_err = 0  # no update.
                    break
                est_T = self.update(est_T, pose)
                self.est_T_series[itr + 1] = est_T
        finally:
            self.feature_model.train(training)

        # Slots after the last applied update repeat the final estimate
        # (broadcast over the leading iteration axis; empty when maxiter == 0).
        self.est_T_series[(itr + 1):] = est_T
        self.est_T = est_T
        return self._make_result(est_T, r, source, itr + 1)

    def update(self, g, dx):
        """Left-compose the estimate: returns ``exp(dx) @ g``.

        g: [B, 4, 4] current transform; dx: [B, 6] twist.
        """
        return self.exp(dx).matmul(g)

    def approx_Jic(self, template, template_features, dt):
        """Finite-difference Jacobian of pooled features w.r.t. the twist.

        Column k is ``(f0 - f(exp(-dt[:, k] e_k) * template)) / dt[:, k]``,
        where ``f`` is ``pooling(feature_model(.))`` and ``f0`` is
        ``template_features``.

        Args:
            template: [B, N, 3] points.
            template_features: [B, K] pooled features of ``template``.
            dt: [B, 6] per-component step sizes.

        Returns:
            J: [B, K, 6].
        """
        batch_size, num_points = template.size(0), template.size(1)
        # Row k of diag(dt[b]) is a twist along component k only, so one
        # batched exp over [B * 6, 6] yields every perturbation at once.
        twists = -torch.diag_embed(dt)  # [B, 6, 6]
        transf = self.exp(twists.reshape(-1, 6)).to(template)  # [B * 6, 4, 4]
        transf = transf.reshape(batch_size, 6, 1, 4, 4).contiguous()  # [B, 6, 1, 4, 4]
        p = self.transform(transf, template.unsqueeze(1))  # x [B, 1, N, 3] -> [B, 6, N, 3]
        f = self.pooling(self.feature_model(p.view(-1, num_points, 3)))
        f = f.view(batch_size, 6, -1).transpose(1, 2)  # [B, K, 6]
        df = template_features.unsqueeze(-1) - f  # [B, K, 6]
        return df / dt.unsqueeze(1)

    def compute_inverse_jacobian(self, J, template_features, source):
        """Return ``pinv(J) = (J^T J)^-1 J^T`` ([B, 6, K]) to solve ``J x = -r``.

        If the inversion raises ``RuntimeError`` (e.g. singular ``J^T J``,
        possibly because ``self.dt`` is too small), the error is stored in
        ``self.last_err`` and an empty dict is returned instead of a tensor.
        ``template_features`` and ``source`` are unused; they are kept for
        interface compatibility.
        """
        Jt = J.transpose(1, 2)  # [B, 6, K]
        H = Jt.bmm(J)  # [B, 6, 6]
        try:
            H_inv = self.inverse(H)
        except RuntimeError as err:
            # singular...? Perhaps an MP-inverse could be used here instead.
            self.last_err = err
            return {}
        return H_inv.bmm(Jt)  # [B, 6, K]

    def handle_batchNorm(self, template, source):
        """Freeze the feature model's normalisation for the registration loop.

        If ``feature_model`` is training, one forward pass is run on each input
        (per the original note, to update BatchNorm modules), then the model
        is switched to eval mode. Returns the original ``training`` flag so the
        caller can restore it.
        """
        training = self.feature_model.training
        if training:
            # Outputs are discarded; these passes run only for their effect on
            # the model's running statistics.
            self.pooling(self.feature_model(template))
            self.pooling(self.feature_model(source))
            self.feature_model.eval()  # and fix them.
        return training
if __name__ == '__main__':
    # Smoke test: register two random clouds and report what came back.
    # (Replaces a leftover ``ipdb.set_trace()`` breakpoint, which needed a
    # third-party debugger and stopped every run in an interactive prompt.)
    template, source = torch.rand(10, 1024, 3), torch.rand(10, 1024, 3)
    pn = PointNet()
    net = PointNetLK(pn)
    result = net(template, source)
    for key, value in result.items():
        print(key, tuple(value.shape) if torch.is_tensor(value) else value)