| |
| |
| |
| |
| |
|
|
| import math |
|
|
| import torch |
| from torch import Tensor |
|
|
| from flow_matching.utils.manifolds import Manifold |
|
|
|
|
class FlatTorus(Manifold):
    r"""Flat torus :math:`(\mathbb{R} / 2\pi\mathbb{Z})^D`, isometric to a product of 1-D spheres.

    Points are stored as angle coordinates that are kept (nominally) in
    :math:`[0, 2\pi)` along every dimension. Because the metric is flat, every
    operation acts independently and element-wise on each coordinate.
    """

    def expmap(self, x: Tensor, u: Tensor) -> Tensor:
        r"""Move from ``x`` along tangent vector ``u`` and wrap back onto the torus.

        Args:
            x: Base point, as angle coordinates.
            u: Tangent vector at ``x``.

        Returns:
            ``x + u`` reduced modulo :math:`2\pi` element-wise (Python/torch
            remainder semantics, so negative inputs wrap to non-negative values).
        """
        period = 2.0 * math.pi
        moved = x + u
        return moved % period

    def logmap(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Return the tangent vector at ``x`` that points to ``y`` along the shortest arc.

        The raw difference ``y - x`` is mapped back to a principal angle via
        ``atan2(sin, cos)``. This makes the result invariant to multiples of
        :math:`2\pi` in either input and keeps it within :math:`[-\pi, \pi]`.

        Args:
            x: Base point, as angle coordinates.
            y: Target point, as angle coordinates.

        Returns:
            Element-wise signed angular displacement from ``x`` to ``y``.
        """
        delta = y - x
        return torch.atan2(torch.sin(delta), torch.cos(delta))

    def projx(self, x: Tensor) -> Tensor:
        r"""Project arbitrary angle coordinates onto the canonical chart.

        Args:
            x: Angle coordinates, possibly outside :math:`[0, 2\pi)`.

        Returns:
            ``x`` reduced modulo :math:`2\pi` element-wise.
        """
        period = 2.0 * math.pi
        return x % period

    def proju(self, x: Tensor, u: Tensor) -> Tensor:
        """Project ``u`` onto the tangent space at ``x``.

        The tangent space of the flat torus is all of R^D at every point, so
        the projection is the identity and ``x`` is unused.

        Args:
            x: Base point (unused).
            u: Candidate tangent vector.

        Returns:
            ``u`` unchanged (the same tensor object).
        """
        return u
|
|