"""Discretized Spherical Harmonics.

Evaluates the spherical-harmonic basis functions ``SH(m, l, phi, theta)`` once
on a fixed 360 x 180 lon/lat grid and looks them up at query points by bilinear
interpolation instead of evaluating ``SH`` directly.
"""
import torch
from torch import nn

from .spherical_harmonics_ylm import SH


def SH_(args):
    """Call ``SH`` with a single packed argument tuple, e.g. ``(m, l, phi, theta)``.

    NOTE(review): presumably exists so ``SH`` can be used with single-argument
    mapping APIs (e.g. ``multiprocessing.Pool.map``) — confirm with callers.
    """
    return SH(*args)


class DiscretizedSphericalHarmonics(nn.Module):
    """Spherical-harmonic embedding looked up from a precomputed grid.

    All basis functions for degrees ``0 <= l < legendre_polys`` and orders
    ``-l <= m <= l`` are evaluated once on a 360 (longitude) x 180 (latitude)
    grid; ``forward`` bilinearly interpolates that grid at the query points.
    """

    def __init__(self, legendre_polys: int = 10):
        """
        Args:
            legendre_polys: number of Legendre polynomial degrees ``L``. More
                degrees give a finer-grained spatial resolution. The embedding
                has ``L * L`` dimensions, one per ``(l, m)`` pair.
        """
        super().__init__()

        self.L, self.M = int(legendre_polys), int(legendre_polys)
        self.embedding_dim = self.L * self.M

        lon = torch.linspace(-180, 180, 360)
        lat = torch.linspace(-90, 90, 180)

        # (180, 360) grids: rows follow latitude, columns follow longitude.
        # Broadcasting gives the same grid as meshgrid(lon, lat) followed by a
        # transpose, without meshgrid's missing-`indexing` deprecation warning.
        grid_shape = (lat.numel(), lon.numel())
        lons = lon.unsqueeze(0).expand(grid_shape)
        lats = lat.unsqueeze(1).expand(grid_shape)

        phi = torch.deg2rad(lons + 180)
        theta = torch.deg2rad(lats + 90)

        basis = []
        for l in range(self.L):
            for m in range(-l, l + 1):
                # Multiplying by ones broadcasts any basis function that SH
                # returns with a smaller shape (presumably a constant, e.g. l == 0)
                # up to the full grid so all entries can be stacked.
                basis.append(SH(m, l, phi, theta) * torch.ones_like(phi))

        # (E, 180, 360) -> (E, 360, 180): axis 1 indexes longitude and axis 2
        # latitude, matching the (lon, lat) column order forward() passes to the
        # interpolator. A non-persistent buffer follows module.to(device) (so the
        # grid is not copied on every forward) and leaves state_dict unchanged.
        self.register_buffer("Ys", torch.stack(basis).permute(0, 2, 1), persistent=False)

    def forward(self, lonlat):
        """Embed query coordinates.

        Args:
            lonlat: ``(N, 2)`` tensor of ``(longitude, latitude)`` in degrees,
                presumably within [-180, 180] x [-90, 90]. Positions that fall
                outside the grid are clamped to its border.

        Returns:
            ``(N, L * L)`` tensor of interpolated basis-function values.
        """
        # Shift to non-negative offsets that are used directly as grid indices.
        # NOTE(review): the grid spacing is 360/359 deg (lon) and 180/179 deg
        # (lat), so treating 1 degree as 1 cell is a slight approximation —
        # confirm this is intended before changing it.
        lonlat = lonlat + torch.tensor([180, 90], device=lonlat.device)
        Ys = interpolate_pixel_values(self.Ys.to(lonlat.device), lonlat).T
        return Ys

    def get_coeffs(self, l, m):
        """Return the coefficient for degree ``l`` and order ``m`` (``m`` may be negative).

        Convenience accessor for two triangular matrices packed into one
        ``self.weight`` tensor: ``m >= 0`` entries are read from row ``l``,
        ``m < 0`` entries from row ``-l`` (counted from the end) at column ``m``.

        NOTE(review): this class never defines ``self.weight``; the method only
        works if a subclass or caller sets it — confirm or remove.
        """
        if m == 0:
            return self.weight[l, 0]
        if m > 0:  # on the diagonal and right of it
            return self.weight[l, m]
        if m < 0:  # left of the diagonal
            return self.weight[-l, m]

    def get_weight_matrix(self):
        """Unfold the packed weights into an ``(L, 2 * L + 1, E)`` double-triangle matrix.

        Rows index the Legendre degree ``l``; columns index the order
        ``m = -L ... L`` (stored at column ``m + L``). Entries with ``|m| > l``
        remain zero.

        NOTE(review): relies on ``self.weight`` and ``self.E``, neither of which
        is defined in this class — confirm where they come from.
        """
        unfolded_coeffs = torch.zeros(self.L, self.L * 2 + 1, self.E, device=self.weight.device)
        for l in range(0, self.L):
            for m in range(-l, l + 1):
                unfolded_coeffs[l, m + self.L] = self.get_coeffs(l, m)
        return unfolded_coeffs


def interpolate_pixel_values(image, points):
    """Bilinearly interpolate a multi-channel image at sub-pixel positions.

    Args:
        image: ``(C, H, W)`` tensor.
        points: ``(N, 2)`` tensor of fractional positions in pixel units;
            column 0 indexes the ``H`` axis and column 1 the ``W`` axis.
            Positions outside the image are clamped to the border pixels.

    Returns:
        ``(C, N)`` tensor of interpolated values.
    """
    rows, cols = image.size(1), image.size(2)

    # Integer corners around each point; the fractional offsets from the lower
    # corner are taken before clamping and serve as the interpolation weights.
    floor_coords = torch.floor(points).long()
    ceil_coords = torch.ceil(points).long()
    frac_coords = points - floor_coords.float()

    # Clamp corner indices to the image so out-of-range points replicate the border.
    r0 = torch.clamp(floor_coords[:, 0], 0, rows - 1)
    c0 = torch.clamp(floor_coords[:, 1], 0, cols - 1)
    r1 = torch.clamp(ceil_coords[:, 0], 0, rows - 1)
    c1 = torch.clamp(ceil_coords[:, 1], 0, cols - 1)

    fr, fc = frac_coords[:, 0], frac_coords[:, 1]

    # Bilinear interpolation needs all four corners; these weights sum to 1,
    # so the result is a convex combination of the surrounding pixels.
    # (N,) weights broadcast against the (C, N) gathered pixel values.
    return (
        image[:, r0, c0] * ((1 - fr) * (1 - fc))
        + image[:, r1, c0] * (fr * (1 - fc))
        + image[:, r0, c1] * ((1 - fr) * fc)
        + image[:, r1, c1] * (fr * fc)
    )