Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
| from .spherical_harmonics_ylm import SH as SH_analytic | |
| from .spherical_harmonics_closed_form import SH as SH_closed_form | |
| """ | |
| Spherical Harmonics location encoder | |
| """ | |
class SphericalHarmonics(nn.Module):
    """Location encoder that evaluates spherical-harmonic basis functions.

    For every degree ``l`` in ``[0, L)`` and every order ``m`` in ``[-l, l]``
    the configured ``SH`` routine is evaluated at the given coordinates, and
    the results are stacked into one feature vector of length ``L * L``.
    """

    def __init__(self, legendre_polys: int = 10, harmonics_calculation: str = "analytic"):
        """
        legendre_polys: number of Legendre polynomials (degrees) to use.
            More polynomials lead to a more fine-grained resolution.
        harmonics_calculation: how the spherical harmonics are computed:
            "analytic" uses pre-computed equations. This is exact, but works
                only up to degree 50.
            "closed-form" uses one equation but is computationally slower
                (especially for high degrees).

        Raises:
            ValueError: if ``legendre_polys`` is smaller than 1 or
                ``harmonics_calculation`` is not one of the two supported names.
        """
        super().__init__()
        self.L, self.M = int(legendre_polys), int(legendre_polys)
        if self.L < 1:
            # With no degrees, forward() would have nothing to stack and fail
            # there; reject the configuration up front instead.
            raise ValueError(
                f"legendre_polys must be >= 1, got {legendre_polys!r}"
            )
        # sum_{l=0}^{L-1} (2l + 1) == L * L harmonics are produced by forward().
        self.embedding_dim = self.L * self.M

        if harmonics_calculation == "closed-form":
            self.SH = SH_closed_form
        elif harmonics_calculation == "analytic":
            self.SH = SH_analytic
        else:
            # Fail at construction rather than with an AttributeError on
            # self.SH the first time forward() is called.
            raise ValueError(
                "harmonics_calculation must be 'analytic' or 'closed-form', "
                f"got {harmonics_calculation!r}"
            )

    def forward(self, lonlat):
        """Encode coordinates given in degrees.

        lonlat: tensor whose last dimension holds (longitude, latitude) in
            degrees; any leading batch dimensions are preserved.

        Returns a tensor with the same leading dimensions and a last
        dimension of size ``L * L``, ordered by degree and then by order.
        """
        # Ellipsis indexing accepts (N, 2) as well as any other (..., 2) shape.
        lon, lat = lonlat[..., 0], lonlat[..., 1]

        # Shift to non-negative angles, then convert degrees to radians.
        # NOTE(review): presumably lon is in [-180, 180] and lat in [-90, 90],
        # giving phi in [0, 2*pi] and theta in [0, pi] -- confirm with callers.
        phi = torch.deg2rad(lon + 180)
        theta = torch.deg2rad(lat + 90)

        harmonics = []
        for degree in range(self.L):
            for order in range(-degree, degree + 1):
                y = self.SH(order, degree, phi, theta)
                # SH may return a plain Python number for a constant harmonic;
                # expand it to a tensor shaped like the inputs so it stacks.
                if isinstance(y, (int, float)):
                    y = y * torch.ones_like(phi)
                harmonics.append(y)
        return torch.stack(harmonics, dim=-1)