""" 3-d rotation group and corresponding Lie algebra """
import torch
from . import sinc
from .sinc import sinc1, sinc2, sinc3


def cross_prod(x, y):
    """Cross product of 3-vectors; x and y share the same [*, 3] shape."""
    z = torch.cross(x.view(-1, 3), y.view(-1, 3), dim=1).view_as(x)
    return z


def liebracket(x, y):
    """Lie bracket on so(3) in vector coordinates (= cross product)."""
    return cross_prod(x, y)


def mat(x):
    """Hat operator: [*, 3] vectors -> [*, 3, 3] skew-symmetric matrices.

    mat(x) @ y equals cross(x, y).
    """
    x_ = x.view(-1, 3)
    x1, x2, x3 = x_[:, 0], x_[:, 1], x_[:, 2]
    O = torch.zeros_like(x1)
    X = torch.stack((
        torch.stack((O, -x3, x2), dim=1),
        torch.stack((x3, O, -x1), dim=1),
        torch.stack((-x2, x1, O), dim=1)), dim=1)
    return X.view(*(x.size()[0:-1]), 3, 3)


def vec(X):
    """Vee operator (inverse of `mat`): [*, 3, 3] -> [*, 3].

    Only entries (2, 1), (0, 2) and (1, 0) are read, so X is assumed to be
    skew-symmetric.
    """
    X_ = X.view(-1, 3, 3)
    x1, x2, x3 = X_[:, 2, 1], X_[:, 0, 2], X_[:, 1, 0]
    x = torch.stack((x1, x2, x3), dim=1)
    return x.view(*X.size()[0:-2], 3)


def genvec():
    """Basis vectors e1, e2, e3 of so(3) as the rows of a [3, 3] tensor."""
    return torch.eye(3)


def genmat():
    """Generators of so(3) as a [3, 3, 3] tensor; genmat()[k] == mat(e_k)."""
    return mat(genvec())


def _rodrigues(x, s1, s2):
    """Rodrigues' rotation formula R = I + s1(t)*W + s2(t)*W^2.

    x: [*, 3] rotation vectors, t = |x|, W = mat(x).  s1/s2 compute
    sin(t)/t and (1 - cos(t))/t^2 (plain or autograd-aware variants).
    Returns [*, 3, 3].
    """
    w = x.view(-1, 3)
    t = w.norm(p=2, dim=1).view(-1, 1, 1)
    W = mat(w)
    S = W.bmm(W)
    I = torch.eye(3).to(w)
    # R = cos(t)*eye(3) + sinc1(t)*W + sinc2(t)*(w*w')
    #   = eye(3) + sinc1(t)*W + sinc2(t)*S
    R = I + s1(t) * W + s2(t) * S
    return R.view(*(x.size()[0:-1]), 3, 3)


def RodriguesRotation(x):
    """Exp map [*, 3] -> [*, 3, 3] using the autograd-aware sinc functions."""
    # for autograd
    return _rodrigues(x, sinc.Sinc1, sinc.Sinc2)


def exp(x):
    """Exp map so(3) -> SO(3): [*, 3] rotation vectors -> [*, 3, 3]."""
    return _rodrigues(x, sinc1, sinc2)


def inverse(g):
    """Inverse of rotation matrices [*, 3, 3] (their transpose)."""
    R = g.view(-1, 3, 3)
    Rt = R.transpose(1, 2)
    return Rt.view_as(g)


def btrace(X):
    """Batch trace: [*, N, N] -> [*].

    Vectorized (no Python loop over the batch); an unbatched [N, N] input
    yields a 0-d tensor.
    """
    return torch.diagonal(X, dim1=-2, dim2=-1).sum(-1)


def log(g):
    """Log map SO(3) -> so(3): [*, 3, 3] rotation matrices -> [*, 3] vectors.

    The returned vector has norm t = acos((trace(R) - 1) / 2) in [0, pi].
    For rotations by (almost exactly) pi, w and -w describe the same rotation;
    the component with the largest magnitude is returned positive.
    """
    eps = 1.0e-7
    R = g.view(-1, 3, 3)
    tr = btrace(R)
    # Rounding can push (tr - 1) / 2 slightly outside [-1, 1] (e.g. a rotation
    # by pi whose trace evaluates to -1.0000001); acos would return NaN, no
    # branch below would fire and the result would silently be zero.
    c = ((tr - 1) / 2).clamp(-1.0, 1.0)
    t = torch.acos(c)
    sc = sinc1(t)
    idx0 = (torch.abs(sc) <= eps)  # t ~= pi: antisymmetric part vanishes
    idx1 = (torch.abs(sc) > eps)
    sc = sc.view(-1, 1, 1)

    X = torch.zeros_like(R)
    if idx1.any():
        # R - R^T = 2 * sinc1(t) * W
        X[idx1] = (R[idx1] - R[idx1].transpose(1, 2)) / (2 * sc[idx1])

    if idx0.any():
        # With t == pi, sinc2(t) == 2 / t^2, so the symmetric part gives
        #   ((R + R^T)/2 + I) * t^2 / 2 = t^2 I + W^2 = w w^T.
        Rs = R[idx0]
        t2 = t[idx0].view(-1, 1, 1) ** 2
        I = torch.eye(3).to(R).unsqueeze(0)
        A = ((Rs + Rs.transpose(1, 2)) / 2 + I) * t2 / 2
        d = torch.diagonal(A, dim1=-2, dim2=-1)  # ~= w_i^2
        # Pivot on the largest |w_k| (its diagonal is >= t^2/3 > 0): row k of
        # A is w_k * w, so dividing by sqrt(A_kk) = |w_k| recovers +-w with
        # consistent signs for every component (including zero components).
        k = d.argmax(dim=1)
        n = torch.arange(A.size(0), device=A.device)
        w = A[n, k] / torch.sqrt(d[n, k]).unsqueeze(-1)
        X[idx0] = mat(w)

    x = vec(X.view_as(g))
    return x


def transform(g, a):
    """Apply rotations to points.

    g in SO(3): * x 3 x 3
    a in R^3:   * x 3[x N]  (a column-stack of points when ranks match g)
    """
    if len(g.size()) == len(a.size()):
        b = g.matmul(a)
    else:
        b = g.matmul(a.unsqueeze(-1)).squeeze(-1)
    return b


def group_prod(g, h):
    """Group product g * h of rotation matrices."""
    # g, h : SO(3)
    g1 = g.matmul(h)
    return g1


def vecs_Xg_ig(x):
    """ Vi = vec(dg/dxi * inv(g)), where g = exp(x)
        (== [Ad(exp(x))] * vecs_ig_Xg(x))

    x: [*, 3] -> [*, 3, 3]; the i-th column of the result is Vi.
    """
    # Flatten first (as `exp` does) so bmm sees [B, 3, 3] for [3] or
    # [B, 1, 3] inputs too.
    w = x.view(-1, 3)
    t = w.norm(p=2, dim=1).view(-1, 1, 1)
    X = mat(w)
    S = X.bmm(X)
    I = torch.eye(3).to(X)
    # V = sinc1(t)*eye(3) + sinc2(t)*X + sinc3(t)*(x*x')
    #   = eye(3) + sinc2(t)*X + sinc3(t)*S
    V = I + sinc2(t) * X + sinc3(t) * S
    return V.view(*(x.size()[0:-1]), 3, 3)


def inv_vecs_Xg_ig(x):
    """ H = inv(vecs_Xg_ig(x)); x: [*, 3] -> [*, 3, 3] """
    w = x.view(-1, 3)
    t = w.norm(p=2, dim=1).view(-1, 1, 1)
    X = mat(w)
    S = X.bmm(X)
    I = torch.eye(3).to(x)

    # eta(t) = (1 - (t/2) / tan(t/2)) / t^2; use its Taylor series near 0,
    # where the closed form suffers from cancellation.
    e = 0.01
    eta = torch.zeros_like(t)
    s = (t < e)
    c = ~s
    t2 = t[s] ** 2
    eta[s] = ((t2 / 40 + 1) * t2 / 42 + 1) * t2 / 720 + 1 / 12  # O(t**8)
    tc = t[c]
    eta[c] = (1 - (tc / 2) / torch.tan(tc / 2)) / (tc ** 2)

    H = I - 1 / 2 * X + eta * S
    return H.view(*(x.size()[0:-1]), 3, 3)


class ExpMap(torch.autograd.Function):
    """ Exp: so(3) -> SO(3) """
    @staticmethod
    def forward(ctx, x):
        """ Exp: R^3 -> M(3),
            size: [B, 3] -> [B, 3, 3],
              or  [B, 1, 3] -> [B, 1, 3, 3]
        """
        ctx.save_for_backward(x)
        g = exp(x)
        return g

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        w = x.view(-1, 3)
        g = exp(w)  # [B, 3, 3]

        # Let z = f(g) = f(exp(x))
        # dz = df/dgij * dgij/dxk * dxk
        #    = df/dgij * (d/dxk)[exp(x)]_ij * dxk
        #    = df/dgij * [mat(V_k) * g]_ij * dxk
        # with V_k the k-th column of V = vecs_Xg_ig(x) (left Jacobian).
        # At x == 0, V == I and mat(V_k) reduces to the generator gen_k.
        V = vecs_Xg_ig(w)                              # [B, 3, 3]
        Vk = mat(V.transpose(1, 2).contiguous())       # [B, k, 3, 3]
        dg = Vk.matmul(g.unsqueeze(1))                 # [B, k, i, j]
        dg = dg.to(grad_output)

        go = grad_output.contiguous().view(-1, 1, 3, 3)
        dd = go * dg
        grad_input = dd.sum(-1).sum(-1)                # [B, 3]
        # autograd requires a gradient shaped like the input (e.g. [B, 1, 3]).
        return grad_input.view_as(x)


Exp = ExpMap.apply

#EOF