Spaces:
Sleeping
Sleeping
| import torch | |
| import lietorch | |
| from lietorch import SO3, RxSO3, SE3, Sim3 | |
| from gradcheck import gradcheck, get_analytical_jacobian | |
| ### forward tests ### | |
def make_homogeneous(p):
    """Append a trailing coordinate equal to 1 to every vector in ``p``.

    Args:
        p: tensor whose last dimension holds the coordinates.

    Returns:
        A tensor with the same leading dimensions as ``p`` and a last
        dimension that is one larger; the added entry is 1 (same dtype and
        device as ``p``).
    """
    ones = torch.ones_like(p[..., :1])
    return torch.cat([p, ones], dim=-1)
def matv(A, b):
    """Batched matrix-vector product.

    Args:
        A: tensor whose last two dimensions form the matrix.
        b: tensor whose last dimension is the vector.

    Returns:
        ``A @ b`` computed over the leading (batch) dimensions, with the
        temporary column dimension removed again.
    """
    # Lift b to a column vector so torch.matmul treats it as a matrix, then
    # drop the singleton column that results.
    column = b[..., None]
    return torch.matmul(A, column)[..., 0]
def test_exp_log(Group, device='cuda'):
    """Check that Log(Exp(x)) == x.

    Args:
        Group: Lie group class under test; must provide ``manifold_dim``
            and ``exp``.
        device: torch device on which the random input is created.

    Raises:
        AssertionError: if the round trip differs from the input by more
            than ``atol=1e-8``.
    """
    a = .2 * torch.randn(2, 3, 4, 5, 6, 7, Group.manifold_dim, device=device).double()
    b = Group.exp(a).log()

    assert torch.allclose(a, b, atol=1e-8), "should be identity"
    print("\t-", Group, "Passed exp-log test")
def test_inv(Group, device='cuda'):
    """Check that Log(X * X^{-1}) == 0, i.e. X * X^{-1} is the identity.

    Args:
        Group: Lie group class under test; must provide ``manifold_dim``
            and ``exp``.
        device: torch device on which the random input is created.

    Raises:
        AssertionError: if the logarithm of ``X * X.inv()`` is not zero
            within ``atol=1e-8``.
    """
    X = Group.exp(.1 * torch.randn(2, 3, 4, 5, Group.manifold_dim, device=device).double())
    a = (X * X.inv()).log()

    assert torch.allclose(a, torch.zeros_like(a), atol=1e-8), "should be 0"
    print("\t-", Group, "Passed inv test")
def test_adj(Group, device='cuda'):
    """Check that X * Exp(a) == Exp(Adj(X, a)) * X.

    The two sides are compared by taking the logarithm of
    ``Y1 * Y2^{-1}`` and requiring it to be zero.

    Args:
        Group: Lie group class under test; must provide ``manifold_dim``
            and ``exp``.
        device: torch device on which the random inputs are created.

    Raises:
        AssertionError: if the residual is not zero within ``atol=1e-8``.
    """
    X = Group.exp(torch.randn(2, 3, 4, 5, Group.manifold_dim, device=device).double())
    a = torch.randn(2, 3, 4, 5, Group.manifold_dim, device=device).double()

    b = X.adj(a)
    Y1 = X * Group.exp(a)
    Y2 = Group.exp(b) * X

    c = (Y1 * Y2.inv()).log()
    assert torch.allclose(c, torch.zeros_like(c), atol=1e-8), "should be 0"
    print("\t-", Group, "Passed adj test")
def test_act(Group, device='cuda'):
    """Check that ``X.act(p)`` agrees with multiplying by ``X.matrix()``.

    The point ``p`` is extended with a trailing 1, multiplied by the matrix
    of ``X``, and the first three components of the product are compared
    with the result of ``X.act(p)``.

    Args:
        Group: Lie group class under test; must provide ``manifold_dim``
            and ``exp``.
        device: torch device on which the random inputs are created.

    Raises:
        AssertionError: if the two results differ by more than
            ``atol=1e-8``.
    """
    X = Group.exp(torch.randn(1, Group.manifold_dim, device=device).double())
    p = torch.randn(1, 3, device=device).double()

    p1 = X.act(p)
    p2 = matv(X.matrix(), make_homogeneous(p))

    # The old message ("should be 0") did not describe this comparison.
    assert torch.allclose(p1, p2[..., :3], atol=1e-8), \
        "act(p) should match matrix() applied to homogeneous p"
    print("\t-", Group, "Passed act test")
| ### backward tests ### | |
def test_exp_log_grad(Group, device='cuda', tol=1e-8):
    """Check that the analytical Jacobian of a -> Log(Exp(a)) is the identity.

    The check is run twice: once at ``a = 0`` and once at a small random
    ``a``.

    Args:
        Group: Lie group class under test; must provide ``manifold_dim``
            and ``exp``.
        device: torch device on which the inputs are created.
        tol: absolute tolerance used when comparing with the identity.

    Raises:
        AssertionError: if either Jacobian differs from the identity by
            more than ``tol``.
    """
    D = Group.manifold_dim

    def fn(a):
        return Group.exp(a).log()

    eye = torch.eye(D, device=device).double()
    inputs = (
        torch.zeros(1, D, requires_grad=True, device=device).double(),
        .2 * torch.randn(1, D, requires_grad=True, device=device).double(),
    )

    for a in inputs:
        # get_analytical_jacobian returns a 4-tuple; only the Jacobians
        # (first element) are needed here.
        analytical = get_analytical_jacobian((a,), fn(a))[0]
        assert torch.allclose(analytical[0], eye, atol=tol)

    print("\t-", Group, "Passed eye-grad test")
def test_inv_log_grad(Group, device='cuda', tol=1e-8):
    """Compare analytical and numerical Jacobians of a -> Log((Exp(a) * X)^{-1}).

    ``X`` is a fixed random group element and the Jacobian is evaluated at
    ``a = 0``.

    Args:
        Group: Lie group class under test; must provide ``manifold_dim``
            and ``exp``.
        device: torch device on which the inputs are created.
        tol: absolute tolerance for the comparison.

    Note:
        A mismatch is reported on stdout but does not raise: the hard
        assertion was disabled in the original code and that best-effort
        behaviour is kept. Unlike before, a mismatch is no longer reported
        as "Passed".
    """
    D = Group.manifold_dim
    X = Group.exp(.2 * torch.randn(1, D, device=device).double())

    def fn(a):
        return (Group.exp(a) * X).inv().log()

    a = torch.zeros(1, D, requires_grad=True, device=device).double()
    analytical, numerical = gradcheck(fn, [a], eps=1e-4)

    if not torch.allclose(analytical[0], numerical[0], atol=tol):
        # Dump both Jacobians so the discrepancy can be inspected.
        print(analytical[0])
        print(numerical[0])
        print("\t-", Group, "FAILED inv-grad test")
        return

    print("\t-", Group, "Passed inv-grad test")
def test_adj_grad(Group, device='cuda'):
    """Compare analytical and numerical Jacobians of (a, b) -> (Exp(a) * X).adj(b).

    ``X`` is a fixed random group element; ``a`` is evaluated at zero and
    ``b`` at a random vector.

    Args:
        Group: Lie group class under test; must provide ``manifold_dim``
            and ``exp``.
        device: torch device on which the inputs are created.

    Raises:
        AssertionError: if the Jacobians w.r.t. ``a`` or ``b`` differ by
            more than ``atol=1e-8``.
    """
    D = Group.manifold_dim
    X = Group.exp(.5 * torch.randn(1, Group.manifold_dim, device=device).double())

    def fn(a, b):
        return (Group.exp(a) * X).adj(b)

    a = torch.zeros(1, D, requires_grad=True, device=device).double()
    b = torch.randn(1, D, requires_grad=True, device=device).double()

    analytical, numerical = gradcheck(fn, [a, b], eps=1e-4)
    assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
    assert torch.allclose(analytical[1], numerical[1], atol=1e-8)

    print("\t-", Group, "Passed adj-grad test")
def test_adjT_grad(Group, device='cuda'):
    """Compare analytical and numerical Jacobians of (a, b) -> (Exp(a) * X).adjT(b).

    ``X`` is a fixed random group element; ``a`` is evaluated at zero and
    ``b`` at a random vector.

    Args:
        Group: Lie group class under test; must provide ``manifold_dim``
            and ``exp``.
        device: torch device on which the inputs are created.

    Raises:
        AssertionError: if the Jacobians w.r.t. ``a`` or ``b`` differ by
            more than ``atol=1e-8``.
    """
    D = Group.manifold_dim
    X = Group.exp(.5 * torch.randn(1, Group.manifold_dim, device=device).double())

    def fn(a, b):
        return (Group.exp(a) * X).adjT(b)

    a = torch.zeros(1, D, requires_grad=True, device=device).double()
    b = torch.randn(1, D, requires_grad=True, device=device).double()

    analytical, numerical = gradcheck(fn, [a, b], eps=1e-4)
    assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
    assert torch.allclose(analytical[1], numerical[1], atol=1e-8)

    print("\t-", Group, "Passed adjT-grad test")
def test_act_grad(Group, device='cuda'):
    """Compare analytical and numerical Jacobians of (a, b) -> (X * Exp(a)).act(b).

    ``X`` is a fixed random group element; ``a`` is evaluated at zero and
    ``b`` is a random 3-vector.

    Args:
        Group: Lie group class under test; must provide ``manifold_dim``
            and ``exp``.
        device: torch device on which the inputs are created.

    Raises:
        AssertionError: if the Jacobians w.r.t. ``a`` or ``b`` differ by
            more than ``atol=1e-8``.
    """
    D = Group.manifold_dim
    X = Group.exp(5 * torch.randn(1, D, device=device).double())

    def fn(a, b):
        return (X * Group.exp(a)).act(b)

    a = torch.zeros(1, D, requires_grad=True, device=device).double()
    b = torch.randn(1, 3, requires_grad=True, device=device).double()

    analytical, numerical = gradcheck(fn, [a, b], eps=1e-4)
    assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
    assert torch.allclose(analytical[1], numerical[1], atol=1e-8)

    print("\t-", Group, "Passed act-grad test")
def test_matrix_grad(Group, device='cuda'):
    """Compare analytical and numerical Jacobians of a -> (Exp(a) * X).matrix().

    ``X`` is a fixed random group element and the Jacobian is evaluated at
    ``a = 0``.

    Args:
        Group: Lie group class under test; must provide ``manifold_dim``
            and ``exp``.
        device: torch device on which the inputs are created.

    Raises:
        AssertionError: if the Jacobians differ by more than ``atol=1e-6``.
    """
    D = Group.manifold_dim
    X = Group.exp(torch.randn(1, D, device=device).double())

    def fn(a):
        return (Group.exp(a) * X).matrix()

    a = torch.zeros(1, D, requires_grad=True, device=device).double()
    analytical, numerical = gradcheck(fn, [a], eps=1e-4)
    assert torch.allclose(analytical[0], numerical[0], atol=1e-6)

    print("\t-", Group, "Passed matrix-grad test")
def extract_translation_grad(Group, device='cuda'):
    """Compare analytical and numerical Jacobians of a -> (Exp(a) * X).translation().

    ``X`` is a fixed random group element and the Jacobian is evaluated at
    ``a = 0``. Marked as a prototype by the original author.

    Args:
        Group: Lie group class under test; must provide ``manifold_dim``
            and ``exp``.
        device: torch device on which the inputs are created.

    Raises:
        AssertionError: if the Jacobians differ by more than ``atol=1e-8``.
    """
    D = Group.manifold_dim
    X = Group.exp(5 * torch.randn(1, D, device=device).double())

    def fn(a):
        return (Group.exp(a) * X).translation()

    a = torch.zeros(1, D, requires_grad=True, device=device).double()

    analytical, numerical = gradcheck(fn, [a], eps=1e-4)
    assert torch.allclose(analytical[0], numerical[0], atol=1e-8)

    print("\t-", Group, "Passed translation grad test")
def test_vec_grad(Group, device='cuda', tol=1e-6):
    """Compare analytical and numerical Jacobians of a -> (Exp(a) * X).vec().

    ``X`` is a fixed random group element and the Jacobian is evaluated at
    ``a = 0``.

    Args:
        Group: Lie group class under test; must provide ``manifold_dim``
            and ``exp``.
        device: torch device on which the inputs are created.
        tol: absolute tolerance for the comparison.

    Raises:
        AssertionError: if the Jacobians differ by more than ``tol``.
    """
    D = Group.manifold_dim
    X = Group.exp(5 * torch.randn(1, D, device=device).double())

    def fn(a):
        return (Group.exp(a) * X).vec()

    a = torch.zeros(1, D, requires_grad=True, device=device).double()

    analytical, numerical = gradcheck(fn, [a], eps=1e-4)
    assert torch.allclose(analytical[0], numerical[0], atol=tol)

    print("\t-", Group, "Passed tovec grad test")
def test_fromvec_grad(Group, device='cuda', tol=1e-6):
    """Compare analytical and numerical Jacobians of InitFromVec(a).vec().

    Before being passed to ``Group.InitFromVec`` the raw random vector is
    preprocessed per group: the 4-component block ``q`` is divided by its
    norm, and for RxSO3 / Sim3 the trailing component ``s`` is
    exponentiated.

    Args:
        Group: one of SO3, RxSO3, SE3, Sim3; must provide ``embedded_dim``
            and ``InitFromVec``.
        device: torch device on which the input is created.
        tol: absolute tolerance for the comparison.

    Raises:
        AssertionError: if the Jacobians differ by more than ``tol``.
    """

    def fn(a):
        # NOTE(review): q is presumably the quaternion part and s a log-scale;
        # confirm against the lietorch embedding layout.
        if Group == SO3:
            a = a / a.norm(dim=-1, keepdim=True)

        elif Group == RxSO3:
            q, s = a.split([4, 1], dim=-1)
            q = q / q.norm(dim=-1, keepdim=True)
            a = torch.cat([q, s.exp()], dim=-1)

        elif Group == SE3:
            t, q = a.split([3, 4], dim=-1)
            q = q / q.norm(dim=-1, keepdim=True)
            a = torch.cat([t, q], dim=-1)

        elif Group == Sim3:
            t, q, s = a.split([3, 4, 1], dim=-1)
            q = q / q.norm(dim=-1, keepdim=True)
            a = torch.cat([t, q, s.exp()], dim=-1)

        return Group.InitFromVec(a).vec()

    D = Group.embedded_dim
    a = torch.randn(1, 2, D, requires_grad=True, device=device).double()

    analytical, numerical = gradcheck(fn, [a], eps=1e-4)
    assert torch.allclose(analytical[0], numerical[0], atol=tol)

    print("\t-", Group, "Passed fromvec grad test")
def scale(device='cuda'):
    """Compare analytical and numerical Jacobians of (a, s) -> Log of a scaled SE3.

    Prints the Jacobians with respect to ``s`` before asserting, which
    helps when debugging a mismatch.

    Args:
        device: torch device on which the inputs are created.

    Raises:
        AssertionError: if the Jacobians w.r.t. ``a`` or ``s`` differ by
            more than ``atol=1e-8``.
    """

    def fn(a, s):
        X = SE3.exp(a)
        # NOTE(review): the return value of X.scale(s) is discarded. If
        # scale() returns a new element rather than modifying X in place,
        # `s` has no effect on the output -- confirm against the SE3 API.
        X.scale(s)
        return X.log()

    s = torch.rand(1, requires_grad=True, device=device).double()
    a = torch.randn(1, 6, requires_grad=True, device=device).double()

    analytical, numerical = gradcheck(fn, [a, s], eps=1e-3)
    print(analytical[1])
    print(numerical[1])

    assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
    assert torch.allclose(analytical[1], numerical[1], atol=1e-8)

    print("\t-", "Passed se3-to-sim3 test")
if __name__ == '__main__':
    # Run the forward and backward suites for every group, first on the CPU
    # and then on the GPU. The GPU pass is skipped when CUDA is not
    # available instead of crashing on the first tensor allocation.
    groups = [SO3, RxSO3, SE3, Sim3]

    devices = [('CPU', 'cpu')]
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        devices.append(('GPU', 'cuda'))

    for label, device in devices:
        print("Testing lietorch forward pass ({}) ...".format(label))
        for Group in groups:
            test_exp_log(Group, device=device)
            test_inv(Group, device=device)
            test_adj(Group, device=device)
            test_act(Group, device=device)

        print("Testing lietorch backward pass ({})...".format(label))
        for Group in groups:
            # Sim3 is checked with a looser tolerance than the other groups.
            tol = 1e-3 if Group == Sim3 else 1e-8

            test_exp_log_grad(Group, device=device, tol=tol)
            test_inv_log_grad(Group, device=device, tol=tol)
            test_adj_grad(Group, device=device)
            test_adjT_grad(Group, device=device)
            test_act_grad(Group, device=device)
            test_matrix_grad(Group, device=device)
            extract_translation_grad(Group, device=device)
            test_vec_grad(Group, device=device)
            test_fromvec_grad(Group, device=device)

    if not cuda_available:
        print("Skipping lietorch GPU tests: CUDA is not available")