import torch from torch import nn import numpy as np import math from .common import _cal_freq_list """ Grid, SphereC, SphereCPlus, SphereM, SphereMPlus location encoders """ class GridAndSphere(nn.Module): """ Given a list of (deltaX,deltaY), encode them using the position encoding function """ def __init__(self, coord_dim=2, frequency_num=16, max_radius=0.01, min_radius=0.00001, freq_init="geometric", name="grid"): """ Args: coord_dim: the dimention of space, 2D, 3D, or other frequency_num: the number of different sinusoidal with different frequencies/wavelengths max_radius: the largest context radius this model can handle """ super(GridAndSphere, self).__init__() # change name attribute to emulate the subclass if name == "grid": GridAndSphere.__qualname__ = "Grid" GridAndSphere.__name__ = "Grid" elif name == "spherec": GridAndSphere.__qualname__ = "SphereC" GridAndSphere.__name__ = "SphereC" elif name == "spherecplus": GridAndSphere.__qualname__ = "SphereCPlus" GridAndSphere.__name__ = "SphereCPlus" elif name == "spherem": GridAndSphere.__qualname__ = "SphereM" GridAndSphere.__name__ = "SphereM" elif name == "spheremplus": GridAndSphere.__qualname__ = "SphereMPlus" GridAndSphere.__name__ = "SphereMPlus" self.coord_dim = coord_dim self.frequency_num = frequency_num self.freq_init = freq_init self.max_radius = max_radius self.min_radius = min_radius # the frequence we use for each block, alpha in ICLR paper self.cal_freq_list() self.cal_freq_mat() self.name = name self.embedding_dim = self.cal_embedding_dim() def cal_elementwise_angle(self, coord, cur_freq): ''' Args: coord: the deltaX or deltaY cur_freq: the frequency ''' return coord / (np.power(self.max_radius, cur_freq * 1.0 / (self.frequency_num - 1))) def cal_coord_embed(self, coords_tuple): embed = [] for coord in coords_tuple: for cur_freq in range(self.frequency_num): embed.append(math.sin(self.cal_elementwise_angle(coord, cur_freq))) embed.append(math.cos(self.cal_elementwise_angle(coord, cur_freq))) # embed: shape (input_embed_dim) return embed def cal_embedding_dim(self): # compute the dimention of the encoded spatial relation embedding if self.name == "grid": return int(4 * self.frequency_num) elif self.name == "spherec": return int(6 * self.frequency_num) # xyz instead of lon lat elif self.name == "spherecplus": return int(12 * self.frequency_num) elif self.name == "spherem": return int(10 * self.frequency_num) elif self.name == "spheremplus": return int(16 * self.frequency_num) # FIX def cal_freq_list(self): self.freq_list = _cal_freq_list(self.freq_init, self.frequency_num, self.max_radius, self.min_radius) def cal_freq_mat(self): # freq_mat shape: (frequency_num, 1) freq_mat = np.expand_dims(self.freq_list, axis=1) # self.freq_mat shape: (frequency_num, 2) self.freq_mat = np.repeat(freq_mat, 2, axis=1) def forward(self, coords): device = coords.device dtype = coords.dtype N = coords.size(0) # add 1 context point dimension (unused here) coords = coords[:, None, :] # coords_mat: shape (batch_size, num_context_pt, 2) coords_mat = np.asarray(coords.cpu()) batch_size = coords_mat.shape[0] num_context_pt = coords_mat.shape[1] # coords_mat: shape (batch_size, num_context_pt, 2, 1) coords_mat = np.expand_dims(coords_mat, axis=3) # coords_mat: shape (batch_size, num_context_pt, 2, 1, 1) coords_mat = np.expand_dims(coords_mat, axis=4) # coords_mat: shape (batch_size, num_context_pt, 2, frequency_num, 1) coords_mat = np.repeat(coords_mat, self.frequency_num, axis=3) # coords_mat: shape (batch_size, num_context_pt, 2, frequency_num, 2) coords_mat = np.repeat(coords_mat, 2, axis=4) # spr_embeds: shape (batch_size, num_context_pt, 2, frequency_num, 2) spr_embeds = coords_mat * self.freq_mat if self.name == "grid": # eq 3 in https://arxiv.org/pdf/2201.10489.pdf # code from https://github.com/gengchenmai/space2vec/blob/a29793336e6a1ebdb497289c286a0b4d5a83079f/spacegraph/spacegraph_codebase/SpatialRelationEncoder.py#L135 spr_embeds[:, :, :, :, 0::2] = np.sin(spr_embeds[:, :, :, :, 0::2]) # dim 2i spr_embeds[:, :, :, :, 1::2] = np.cos(spr_embeds[:, :, :, :, 1::2]) # dim 2i+1 elif self.name == "spherec": # eq 4 in https://arxiv.org/pdf/2201.10489.pdf # lambda: longitude, theta=latitude #sin_lon, sin_lat = np.sin(spr_embeds[:, 0, :, :, 0]).transpose(1, 0, 2) #cos_lon, cos_lat = np.cos(spr_embeds[:, 0, :, :, 1]).transpose(1, 0, 2) # eq 4 # sin theta, cos_theta * cos_lambda, cos_theta * sin_lambda # sin lat, cos_lat cos_lon, cos_lat sin_lon #spr_embeds = np.stack([sin_lat, cos_lat*cos_lon, cos_lat*sin_lon], axis=-1) spr_embeds = spr_embeds# * math.pi / 180 # lon, lat: shape (batch_size, num_context_pt, 1, frequency_num, 1) lon = np.expand_dims(spr_embeds[:, :, 0, :, :], axis=2) lat = np.expand_dims(spr_embeds[:, :, 1, :, :], axis=2) # make sinuniod function # lon_sin, lon_cos: shape (batch_size, num_context_pt, 1, frequency_num, 1) lon_sin = np.sin(lon) lon_cos = np.cos(lon) # lat_sin, lat_cos: shape (batch_size, num_context_pt, 1, frequency_num, 1) lat_sin = np.sin(lat) lat_cos = np.cos(lat) # spr_embeds_: shape (batch_size, num_context_pt, 1, frequency_num, 3) spr_embeds_ = np.concatenate([lat_sin, lat_cos * lon_cos, lat_cos * lon_sin], axis=-1) # (batch_size, num_context_pt, frequency_num*3) spr_embeds = np.reshape(spr_embeds_, (batch_size, num_context_pt, -1)) elif self.name == "spherecplus": # eq 10 in https://arxiv.org/pdf/2201.10489.pdf (basically grid + spherec) spr_embeds = spr_embeds# * math.pi / 180 # lon, lat: shape (batch_size, num_context_pt, 1, frequency_num, 1) lon = np.expand_dims(spr_embeds[:, :, 0, :, :], axis=2) lat = np.expand_dims(spr_embeds[:, :, 1, :, :], axis=2) # make sinuniod function # lon_sin, lon_cos: shape (batch_size, num_context_pt, 1, frequency_num, 1) lon_sin = np.sin(lon) lon_cos = np.cos(lon) # lat_sin, lat_cos: shape (batch_size, num_context_pt, 1, frequency_num, 1) lat_sin = np.sin(lat) lat_cos = np.cos(lat) # spr_embeds_: shape (batch_size, num_context_pt, 1, frequency_num, 6) spr_embeds_ = np.concatenate([lat_sin, lat_cos, lon_sin, lon_cos, lat_cos * lon_cos, lat_cos * lon_sin], axis=-1) # (batch_size, num_context_pt, 2*frequency_num*6) spr_embeds = np.reshape(spr_embeds_, (batch_size, num_context_pt, -1)) elif self.name == "spherem": """code from https://github.com/gengchenmai/sphere2vec/blob/8e923bbceab6065cbb4f26398122a5a6f08e0135/main/SpatialRelationEncoder.py#L1753""" # lon, lat: shape (batch_size, num_context_pt, 1, frequency_num, 1) lon_single = np.expand_dims(coords_mat[:, :, 0, :, :], axis=2) lat_single = np.expand_dims(coords_mat[:, :, 1, :, :], axis=2) # make sinuniod function # lon_sin, lon_cos: shape (batch_size, num_context_pt, 1, frequency_num, 1) lon_single_sin = np.sin(lon_single) lon_single_cos = np.cos(lon_single) # lat_sin, lat_cos: shape (batch_size, num_context_pt, 1, frequency_num, 1) lat_single_sin = np.sin(lat_single) lat_single_cos = np.cos(lat_single) # lon, lat: shape (batch_size, num_context_pt, 1, frequency_num, 1) lon = np.expand_dims(spr_embeds[:, :, 0, :, :], axis=2) lat = np.expand_dims(spr_embeds[:, :, 1, :, :], axis=2) # make sinuniod function # lon_sin, lon_cos: shape (batch_size, num_context_pt, 1, frequency_num, 1) lon_sin = np.sin(lon) lon_cos = np.cos(lon) # lat_sin, lat_cos: shape (batch_size, num_context_pt, 1, frequency_num, 1) lat_sin = np.sin(lat) lat_cos = np.cos(lat) # spr_embeds_: shape (batch_size, num_context_pt, 1, frequency_num, 3) spr_embeds = np.concatenate([lat_sin, lat_cos * lon_single_cos, lat_single_cos * lon_cos, lat_cos * lon_single_sin, lat_single_cos * lon_sin], axis=-1) elif self.name == "spheremplus": # lon, lat: shape (batch_size, num_context_pt, 1, frequency_num, 1) lon_single = np.expand_dims(coords_mat[:, :, 0, :, :], axis=2) lat_single = np.expand_dims(coords_mat[:, :, 1, :, :], axis=2) # make sinuniod function # lon_sin, lon_cos: shape (batch_size, num_context_pt, 1, frequency_num, 1) lon_single_sin = np.sin(lon_single) lon_single_cos = np.cos(lon_single) # lat_sin, lat_cos: shape (batch_size, num_context_pt, 1, frequency_num, 1) lat_single_sin = np.sin(lat_single) lat_single_cos = np.cos(lat_single) # lon, lat: shape (batch_size, num_context_pt, 1, frequency_num, 1) lon = np.expand_dims(spr_embeds[:, :, 0, :, :], axis=2) lat = np.expand_dims(spr_embeds[:, :, 1, :, :], axis=2) # make sinuniod function # lon_sin, lon_cos: shape (batch_size, num_context_pt, 1, frequency_num, 1) lon_sin = np.sin(lon) lon_cos = np.cos(lon) # lat_sin, lat_cos: shape (batch_size, num_context_pt, 1, frequency_num, 1) lat_sin = np.sin(lat) lat_cos = np.cos(lat) # spr_embeds_: shape (batch_size, num_context_pt, 1, frequency_num, 3) spr_embeds = np.concatenate( [lat_sin, lat_cos, lon_sin, lon_cos, lat_cos * lon_single_cos, lat_single_cos * lon_cos, lat_cos * lon_single_sin, lat_single_cos * lon_sin], axis=-1) return torch.from_numpy(spr_embeds.reshape(N, -1)).to(dtype).to(device)