| |
| |
| |
| |
| |
|
|
| import abc |
|
|
| import torch.nn as nn |
| from torch import Tensor |
|
|
|
|
class Manifold(nn.Module, metaclass=abc.ABCMeta):
    """Abstract interface for a manifold used as an ``nn.Module``.

    Subclasses must provide four primitives: the exponential and logarithmic
    maps, a projection of arbitrary points onto the manifold, and a projection
    of arbitrary vectors onto the tangent space at a point. Because the class
    is created with ``abc.ABCMeta``, it cannot be instantiated until every
    abstract method below has been overridden.
    """

    @abc.abstractmethod
    def expmap(self, x: Tensor, u: Tensor) -> Tensor:
        r"""Apply the exponential map :math:`\exp_x(u)`.

        Args:
            x (Tensor): base point on the manifold
            u (Tensor): tangent vector attached at :math:`x`

        Raises:
            NotImplementedError: whenever this base implementation is
                reached (e.g. through ``super()``)

        Returns:
            Tensor: the point :math:`\exp_x(u)` on the manifold
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def logmap(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Apply the logarithmic map :math:`\log_x(y)`.

        Args:
            x (Tensor): base point on the manifold
            y (Tensor): target point on the manifold

        Raises:
            NotImplementedError: whenever this base implementation is
                reached (e.g. through ``super()``)

        Returns:
            Tensor: tangent vector at :math:`x`
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def projx(self, x: Tensor) -> Tensor:
        """Map the point :math:`x` onto the manifold.

        Args:
            x (Tensor): arbitrary point to project

        Raises:
            NotImplementedError: whenever this base implementation is
                reached (e.g. through ``super()``)

        Returns:
            Tensor: the projected point, lying on the manifold
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def proju(self, x: Tensor, u: Tensor) -> Tensor:
        """Map the vector :math:`u` onto the tangent space at :math:`x`.

        Args:
            x (Tensor): point on the manifold whose tangent space is used
            u (Tensor): arbitrary vector to project

        Raises:
            NotImplementedError: whenever this base implementation is
                reached (e.g. through ``super()``)

        Returns:
            Tensor: the projected tangent vector
        """
        raise NotImplementedError()
|
|
|
|
class Euclidean(Manifold):
    """Flat Euclidean space, where every manifold primitive is elementary.

    Points and tangent vectors share one vector space, so the exponential and
    logarithmic maps reduce to addition and subtraction, and both projections
    are the identity.
    """

    def expmap(self, x: Tensor, u: Tensor) -> Tensor:
        """Return ``x + u``: step from ``x`` by the vector ``u``."""
        moved = x + u
        return moved

    def logmap(self, x: Tensor, y: Tensor) -> Tensor:
        """Return ``y - x``: the displacement leading from ``x`` to ``y``."""
        # Operand order is kept as ``y - x`` so signed zeros match exactly.
        displacement = y - x
        return displacement

    def projx(self, x: Tensor) -> Tensor:
        """Return ``x`` unchanged; every point already lies on the manifold."""
        return x

    def proju(self, x: Tensor, u: Tensor) -> Tensor:
        """Return ``u`` unchanged; the tangent space at ``x`` is the whole space."""
        return u
|
|