YasiiKB's picture
initial commit
97aa5af verified
""" 3-d rigid body transfomation group and corresponding Lie algebra. """
import torch
from .sinc import sinc1, sinc2, sinc3
from . import so3
def twist_prod(x, y):
    """ Lie bracket of two se(3) twists.

    x, y : tensors whose last dimension is 6, laid out as (w, v) -- the first
           three entries form the rotational part, the last three the
           translational part (the same layout mat() uses).
    Returns a tensor shaped like x whose rows are
        (w_x x w_y,  w_x x v_y + v_x x w_y),
    assuming so3.cross_prod is the row-wise 3-D cross product (defined in
    .so3 -- confirm).
    """
    # reshape (not view): view() raises on non-contiguous inputs such as
    # transposed or expanded batches; reshape is zero-copy when it can be.
    x_ = x.reshape(-1, 6)
    y_ = y.reshape(-1, 6)
    xw, xv = x_[:, 0:3], x_[:, 3:6]
    yw, yv = y_[:, 0:3], y_[:, 3:6]
    zw = so3.cross_prod(xw, yw)
    zv = so3.cross_prod(xw, yv) + so3.cross_prod(xv, yw)
    z = torch.cat((zw, zv), dim=1)  # freshly allocated, hence contiguous
    return z.view_as(x)
def liebracket(x, y):
    """ Lie bracket [x, y] of two se(3) twists.

    Readability alias; the actual computation lives in twist_prod().
    """
    bracket = twist_prod(x, y)
    return bracket
def mat(x):
    """ Hat operator: twist vector(s) -> 4x4 matrix representation.

    size: [*, 6] -> [*, 4, 4]
    For x = (w1, w2, w3, v1, v2, v3) the result is
        [[  0, -w3,  w2, v1],
         [ w3,   0, -w1, v2],
         [-w2,  w1,   0, v3],
         [  0,   0,   0,  0]]
    """
    # reshape (not view): view() raises on non-contiguous inputs such as
    # transposed or expanded batches; reshape is zero-copy when it can be.
    x_ = x.reshape(-1, 6)
    w1, w2, w3 = x_[:, 0], x_[:, 1], x_[:, 2]
    v1, v2, v3 = x_[:, 3], x_[:, 4], x_[:, 5]
    O = torch.zeros_like(w1)
    X = torch.stack((
        torch.stack(( O, -w3,  w2, v1), dim=1),
        torch.stack(( w3,  O, -w1, v2), dim=1),
        torch.stack((-w2,  w1,  O, v3), dim=1),
        torch.stack(( O,   O,   O,  O), dim=1)), dim=1)
    # X comes from torch.stack, so it is contiguous and view() is safe here.
    return X.view(*(x.size()[0:-1]), 4, 4)
def vec(X):
    """ Vee operator, the inverse of mat(): 4x4 matrix -> twist vector.

    size: [*, 4, 4] -> [*, 6]
    Reads w from the skew-symmetric entries (2,1), (0,2), (1,0) and v from
    the last column; no check is made that X is actually skew in its 3x3 block.
    """
    # reshape (not view): view() raises on non-contiguous inputs such as
    # transposed or expanded batches; reshape is zero-copy when it can be.
    X_ = X.reshape(-1, 4, 4)
    w1, w2, w3 = X_[:, 2, 1], X_[:, 0, 2], X_[:, 1, 0]
    v1, v2, v3 = X_[:, 0, 3], X_[:, 1, 3], X_[:, 2, 3]
    x = torch.stack((w1, w2, w3, v1, v2, v3), dim=1)
    # x comes from torch.stack, so it is contiguous and view() is safe here.
    return x.view(*X.size()[0:-2], 6)
def genvec():
    """ Standard basis of se(3) as a [6, 6] tensor.

    Row k is the unit twist e_k; passing it through mat() yields the six
    generator matrices (see genmat).
    """
    return torch.diag(torch.ones(6))
def genmat():
    """ The six se(3) generator matrices, stacked as a [6, 4, 4] tensor.

    Entry k is mat() applied to the k-th unit twist from genvec().
    """
    basis = genvec()
    return mat(basis)
def exp(x):
    """ Exponential map se(3) -> SE(3).

    size: [*, 6] -> [*, 4, 4]
    x is laid out as (w, v). The result is the homogeneous matrix
        [[R, V v],
         [0,   1]]
    where R is given by Rodrigues' formula and V is the matching
    translation factor (see the comments below).
    """
    # reshape (not view): view() raises on non-contiguous inputs such as
    # transposed or expanded batches; reshape is zero-copy when it can be.
    x_ = x.reshape(-1, 6)
    w, v = x_[:, 0:3], x_[:, 3:6]
    t = w.norm(p=2, dim=1).view(-1, 1, 1)  # |w|, shaped to broadcast over 3x3
    W = so3.mat(w)  # presumably the 3x3 skew matrix [w]_x -- defined in .so3
    S = W.bmm(W)
    I = torch.eye(3, dtype=w.dtype, device=w.device)

    # Rodrigues' rotation formula. The identities below assume (per .sinc --
    # confirm) sinc1(t) = sin(t)/t, sinc2(t) = (1 - cos(t))/t^2,
    # sinc3(t) = (t - sin(t))/t^3, and use w*w' = S + t^2*eye(3)
    # to fold the identity terms.
    #R = cos(t)*eye(3) + sinc1(t)*W + sinc2(t)*(w*w');
    #  = eye(3) + sinc1(t)*W + sinc2(t)*S
    R = I + sinc1(t)*W + sinc2(t)*S
    #V = sinc1(t)*eye(3) + sinc2(t)*W + sinc3(t)*(w*w')
    #  = eye(3) + sinc2(t)*W + sinc3(t)*S
    V = I + sinc2(t)*W + sinc3(t)*S

    p = V.bmm(v.contiguous().view(-1, 3, 1))
    # Bottom row [0, 0, 0, 1], created directly with x's dtype and device
    # instead of allocating a CPU float32 tensor and copying it over.
    z = x_.new_tensor([0, 0, 0, 1]).view(1, 1, 4).expand(x_.size(0), 1, 4)
    Rp = torch.cat((R, p), dim=2)
    g = torch.cat((Rp, z), dim=1)
    return g.view(*(x.size()[0:-1]), 4, 4)
def inverse(g):
    """ Inverse of SE(3) element(s).

    size: [*, 4, 4] -> [*, 4, 4]
    For g = [[R, p], [0, 1]] returns [[R^T, -R^T p], [0, 1]]. This closed
    form assumes the top-left 3x3 block of g is a rotation (orthonormal);
    it is not checked.
    """
    # reshape (not view): view() raises on non-contiguous inputs such as
    # transposed or expanded batches; reshape is zero-copy when it can be.
    g_ = g.reshape(-1, 4, 4)
    R = g_[:, 0:3, 0:3]
    p = g_[:, 0:3, 3]
    Q = R.transpose(1, 2)
    q = -Q.matmul(p.unsqueeze(-1))
    # Bottom row [0, 0, 0, 1], created directly with g's dtype and device.
    z = g_.new_tensor([0, 0, 0, 1]).view(1, 1, 4).expand(g_.size(0), 1, 4)
    Qq = torch.cat((Q, q), dim=2)
    ig = torch.cat((Qq, z), dim=1)
    return ig.view(*(g.size()[0:-2]), 4, 4)
def log(g):
    """ Logarithm map SE(3) -> se(3).

    size: [*, 4, 4] -> [*, 6]
    Returns (w, v), where w = so3.log(R) for the rotation block R, and
    v = H p with H = so3.inv_vecs_Xg_ig(w) applied to the translation p.
    NOTE(review): H is presumably the inverse of the V factor used in exp(),
    which would make log(exp(x)) == x -- confirm against .so3.
    """
    # reshape (not view): view() raises on non-contiguous inputs such as
    # transposed or expanded batches; reshape is zero-copy when it can be.
    g_ = g.reshape(-1, 4, 4)
    R = g_[:, 0:3, 0:3]
    p = g_[:, 0:3, 3]
    w = so3.log(R)
    H = so3.inv_vecs_Xg_ig(w)
    v = H.bmm(p.contiguous().view(-1, 3, 1)).view(-1, 3)
    x = torch.cat((w, v), dim=1)
    return x.view(*(g.size()[0:-2]), 6)
def transform(g, a):
    """ Apply rigid transformation(s) g to point(s) a: b = R a + p.

    g : SE(3), * x 4 x 4
    a : R^3,   * x 3[x N]
    Dispatch is by rank: when a has the same number of dimensions as g, it is
    treated as a batch of 3 x N column-point matrices (result * x 3 x N);
    otherwise as a batch of single 3-vectors (result * x 3).
    """
    # reshape (not view): view() raises on non-contiguous inputs such as
    # transposed or expanded batches; reshape is zero-copy when it can be.
    g_ = g.reshape(-1, 4, 4)
    R = g_[:, 0:3, 0:3].contiguous().view(*(g.size()[0:-2]), 3, 3)
    p = g_[:, 0:3, 3].contiguous().view(*(g.size()[0:-2]), 3)
    if len(g.size()) == len(a.size()):
        # a: * x 3 x N -- translate every column.
        b = R.matmul(a) + p.unsqueeze(-1)
    else:
        # a: * x 3 -- treat each point as a column vector, then drop that axis.
        b = R.matmul(a.unsqueeze(-1)).squeeze(-1) + p
    return b
def group_prod(g, h):
    """ Compose SE(3) elements: the matrix product g @ h.

    g, h : homogeneous 4x4 matrices; leading batch dimensions follow
           torch.matmul broadcasting.
    """
    return torch.matmul(g, h)
class ExpMap(torch.autograd.Function):
    """ Exp: se(3) -> SE(3)
    Autograd wrapper around exp() with a hand-written backward.
    """
    @staticmethod
    def forward(ctx, x):
        """ Exp: R^6 -> M(4),
            size: [B, 6] -> [B, 4, 4],
              or  [B, 1, 6] -> [B, 1, 4, 4]
        """
        # Only x is saved; backward recomputes exp(x) rather than storing g.
        ctx.save_for_backward(x)
        g = exp(x)
        return g

    @staticmethod
    def backward(ctx, grad_output):
        """ Gradient w.r.t. the twist x.

        grad_output : dL/dg, shaped like forward's output ([*, 4, 4]).
        returns     : dL/dx, shaped exactly like x ([*, 6]).
        """
        x, = ctx.saved_tensors
        g = exp(x)
        gen_k = genmat().to(x)  # [6, 4, 4]: one generator per twist coordinate
        # Let z = f(g) = f(exp(x))
        # dz = df/dgij * dgij/dxk * dxk
        #    = df/dgij * (d/dxk)[exp(x)]_ij * dxk
        #    = df/dgij * [gen_k*g]_ij * dxk
        # NOTE(review): d exp(x)/dx_k = gen_k @ exp(x) holds exactly only where
        # gen_k commutes with mat(x) (in particular at x = 0); for general x the
        # true derivative involves the SE(3) left Jacobian. Confirm this
        # approximation is intended.
        dg = gen_k.matmul(g.view(-1, 1, 4, 4))
        # (batch, k, i, j)
        dg = dg.to(grad_output)
        go = grad_output.contiguous().view(-1, 1, 4, 4)
        dd = go * dg
        grad_input = dd.sum(-1).sum(-1)  # contract over (i, j) -> [B', 6]
        # Restore x's leading dimensions. Without this, a [B, 1, 6] input
        # (explicitly supported by forward) gets a [B, 6] gradient, which
        # autograd rejects as shape-incompatible.
        return grad_input.view_as(x)


Exp = ExpMap.apply
#EOF