Spaces:
Sleeping
Sleeping
| import lietorch_backends | |
| import torch | |
| import torch.nn.functional as F | |
class GroupOp(torch.autograd.Function):
    """ group operation base class

    Subclasses provide two class attributes:
      * forward_op  -- called as forward_op(group_id, *inputs)
      * backward_op -- called as backward_op(group_id, grad, *inputs); must
        return an iterable of gradients for the inputs, or be None when no
        backward pass is available.
    """

    # forward/backward take `cls` so they can reach the subclass' forward_op /
    # backward_op; they must be classmethods, otherwise autograd's
    # `cls.forward(ctx, ...)` call would bind ctx to the `cls` parameter.
    @classmethod
    def forward(cls, ctx, group_id, *inputs):
        """ run forward_op and stash what backward needs on ctx """
        ctx.group_id = group_id
        ctx.save_for_backward(*inputs)
        out = cls.forward_op(ctx.group_id, *inputs)
        return out

    @classmethod
    def backward(cls, ctx, grad):
        """ run backward_op on the saved inputs; fails if backward_op is None """
        error_str = "Backward operation not implemented for {}".format(cls)
        assert cls.backward_op is not None, error_str

        inputs = ctx.saved_tensors
        # presumably the backend kernels expect a dense gradient -- TODO confirm
        grad = grad.contiguous()
        grad_inputs = cls.backward_op(ctx.group_id, grad, *inputs)
        # leading None: group_id is not differentiable
        return (None, ) + tuple(grad_inputs)
class Exp(GroupOp):
    """ exponential map, backed by the expm / expm_backward kernels """
    forward_op = lietorch_backends.expm
    backward_op = lietorch_backends.expm_backward
class Log(GroupOp):
    """ logarithm map, backed by the logm / logm_backward kernels """
    forward_op = lietorch_backends.logm
    backward_op = lietorch_backends.logm_backward
class Inv(GroupOp):
    """ group inverse, backed by the inv / inv_backward kernels """
    forward_op = lietorch_backends.inv
    backward_op = lietorch_backends.inv_backward
class Mul(GroupOp):
    """ group multiplication, backed by the mul / mul_backward kernels """
    forward_op = lietorch_backends.mul
    backward_op = lietorch_backends.mul_backward
class Adj(GroupOp):
    """ adjoint operator, backed by the adj / adj_backward kernels """
    forward_op = lietorch_backends.adj
    backward_op = lietorch_backends.adj_backward
class AdjT(GroupOp):
    """ adjoint operator, adjT variant (presumably the transposed adjoint -- TODO confirm) """
    forward_op = lietorch_backends.adjT
    backward_op = lietorch_backends.adjT_backward
class Act3(GroupOp):
    """ action on point, backed by the act / act_backward kernels """
    forward_op = lietorch_backends.act
    backward_op = lietorch_backends.act_backward
class Act4(GroupOp):
    """ action on point, backed by the act4 / act4_backward kernels
    (presumably homogeneous 4-vectors -- TODO confirm) """
    forward_op = lietorch_backends.act4
    backward_op = lietorch_backends.act4_backward
class Jinv(GroupOp):
    """ Jinv backend operator, forward only (no backward kernel is bound).
    NOTE(review): presumably the inverse left Jacobian -- confirm against the backend """
    forward_op = lietorch_backends.Jinv
    backward_op = None
class ToMatrix(GroupOp):
    """ convert to matrix representation; forward only (no backward kernel is bound) """
    forward_op = lietorch_backends.as_matrix
    backward_op = None
| ### conversion operations to/from Euclidean embeddings ### | |
class FromVec(torch.autograd.Function):
    """ convert vector into group object

    Forward returns the first input unchanged. Backward multiplies the
    incoming gradient by the pseudo-inverse of the backend projector
    evaluated at the saved inputs.
    """

    # classmethods: autograd calls `cls.forward(ctx, ...)`, and the signatures
    # reserve the first parameter for the class.
    @classmethod
    def forward(cls, ctx, group_id, *inputs):
        """ identity on inputs[0]; saves group_id and inputs for backward """
        ctx.group_id = group_id
        ctx.save_for_backward(*inputs)
        return inputs[0]

    @classmethod
    def backward(cls, ctx, grad):
        """ return (None, grad @ pinv(J)); None is for the group_id argument """
        inputs = ctx.saved_tensors
        J = lietorch_backends.projector(ctx.group_id, *inputs)
        # treat grad as a row vector for the matmul, then drop the extra axis
        return None, torch.matmul(grad.unsqueeze(-2), torch.linalg.pinv(J)).squeeze(-2)
| class ToVec(torch.autograd.Function): | |
| """ convert group object to vector """ | |
| def forward(cls, ctx, group_id, *inputs): | |
| ctx.group_id = group_id | |
| ctx.save_for_backward(*inputs) | |
| return inputs[0] | |
| def backward(cls, ctx, grad): | |
| inputs = ctx.saved_tensors | |
| J = lietorch_backends.projector(ctx.group_id, *inputs) | |
| return None, torch.matmul(grad.unsqueeze(-2), J).squeeze(-2) | |