Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
| import numpy as np | |
| import math | |
| from .common import _cal_freq_list | |
| """ | |
| Grid, SphereC, SphereCPlus, SphereM, SphereMPlus location encoders | |
| """ | |
class GridAndSphere(nn.Module):
    """Multi-scale sinusoidal location encoder (Grid / Sphere2Vec variants).

    One class implements five encoders, selected by ``name``:

    * ``"grid"``        -- sin/cos of every coordinate at every frequency
      (eq. 3 in https://arxiv.org/pdf/2201.10489.pdf).
    * ``"spherec"``     -- eq. 4 of the same paper.
    * ``"spherecplus"`` -- eq. 10 of the same paper (grid terms + spherec terms).
    * ``"spherem"`` / ``"spheremplus"`` -- multi-scale terms mixed with
      single-scale terms of the raw coordinates, following
      https://github.com/gengchenmai/sphere2vec (SpatialRelationEncoder.py).

    ``forward`` maps an (N, 2) coordinate tensor to an (N, embedding_dim)
    tensor with the same dtype and device as the input.
    """

    # name -> (class name shown for this variant, output features per frequency).
    # The sphere variants emit every feature twice (see ``_sphere_embeds``),
    # which is why their counts are twice their number of distinct terms.
    _VARIANTS = {
        "grid": ("Grid", 4),
        "spherec": ("SphereC", 6),
        "spherecplus": ("SphereCPlus", 12),
        "spherem": ("SphereM", 10),
        "spheremplus": ("SphereMPlus", 16),
    }

    def __init__(self, coord_dim=2, frequency_num=16,
                 max_radius=0.01, min_radius=0.00001,
                 freq_init="geometric", name="grid"):
        """
        Args:
            coord_dim: the dimension of the input space (stored as an attribute)
            frequency_num: the number of sinusoid frequencies (scales)
            max_radius: the largest context radius this model can handle
            min_radius: the smallest context radius; passed, like max_radius,
                to the frequency initializer
            freq_init: frequency initialization scheme passed to ``_cal_freq_list``
            name: one of "grid", "spherec", "spherecplus", "spherem", "spheremplus"

        Raises:
            ValueError: if ``name`` is not a supported variant.
        """
        if name not in self._VARIANTS:
            raise ValueError(
                f"unknown encoder name {name!r}; expected one of {sorted(self._VARIANTS)}"
            )
        super(GridAndSphere, self).__init__()
        # Rename the class to emulate a dedicated subclass per variant (e.g. in
        # the printed module name). NOTE: this mutates the shared class object,
        # so all existing instances report the most recently constructed
        # variant's name.
        display_name = self._VARIANTS[name][0]
        GridAndSphere.__qualname__ = display_name
        GridAndSphere.__name__ = display_name

        self.coord_dim = coord_dim
        self.frequency_num = frequency_num
        self.freq_init = freq_init
        self.max_radius = max_radius
        self.min_radius = min_radius
        # the frequency used for each block, alpha in the ICLR paper
        self.cal_freq_list()
        self.cal_freq_mat()
        self.name = name
        self.embedding_dim = self.cal_embedding_dim()

    def cal_elementwise_angle(self, coord, cur_freq):
        """Return ``coord`` divided by the ``cur_freq``-th geometric wavelength.

        The wavelength is ``max_radius ** (cur_freq / (frequency_num - 1))``,
        so ``frequency_num`` must be greater than 1. Not used by ``forward``.

        Args:
            coord: the deltaX or deltaY
            cur_freq: the frequency index
        """
        return coord / (np.power(self.max_radius, cur_freq * 1.0 / (self.frequency_num - 1)))

    def cal_coord_embed(self, coords_tuple):
        """Pure-Python sin/cos embedding of one tuple of scalar coordinates.

        For each coordinate, and for each frequency index, appends
        ``sin(angle)`` then ``cos(angle)`` with ``angle`` from
        ``cal_elementwise_angle``. The result is a list of length
        ``2 * len(coords_tuple) * frequency_num``.
        """
        embed = []
        for coord in coords_tuple:
            for cur_freq in range(self.frequency_num):
                angle = self.cal_elementwise_angle(coord, cur_freq)
                embed.append(math.sin(angle))
                embed.append(math.cos(angle))
        return embed

    def cal_embedding_dim(self):
        """Return the width of the tensor ``forward`` produces for (N, 2) input.

        Raises:
            ValueError: if ``self.name`` is not a supported variant.
        """
        try:
            per_frequency = self._VARIANTS[self.name][1]
        except KeyError:
            raise ValueError(f"unknown encoder name {self.name!r}") from None
        return int(per_frequency * self.frequency_num)

    def cal_freq_list(self):
        """Compute ``self.freq_list`` (one entry per frequency) via ``_cal_freq_list``."""
        self.freq_list = _cal_freq_list(self.freq_init, self.frequency_num, self.max_radius, self.min_radius)

    def cal_freq_mat(self):
        """Build ``self.freq_mat``: ``freq_list`` as a column, repeated into 2 columns."""
        # freq_mat shape: (frequency_num, 1)
        freq_mat = np.expand_dims(self.freq_list, axis=1)
        # self.freq_mat shape: (frequency_num, 2)
        self.freq_mat = np.repeat(freq_mat, 2, axis=1)

    def forward(self, coords):
        """Encode a batch of coordinates.

        Args:
            coords: tensor of shape (N, 2). The sphere variants read column 0
                as longitude and column 1 as latitude; the grid variant
                encodes every column. Values are used as given -- no
                degree-to-radian conversion is applied.

        Returns:
            Tensor of shape (N, embedding_dim) with the dtype and device of
            ``coords``. The encoding is computed in NumPy, so no gradient
            flows back to ``coords``.
        """
        device, dtype = coords.device, coords.dtype
        n = coords.size(0)
        # detach() so inputs that require grad can still be converted to NumPy.
        coords_np = coords.detach().cpu().numpy()
        # Both arrays have shape (N, D, frequency_num, 2). freq_mat repeats each
        # frequency on its last axis, so that axis holds two identical copies.
        raw = np.broadcast_to(coords_np[:, :, None, None],
                              coords_np.shape + self.freq_mat.shape)
        scaled = raw * self.freq_mat

        if self.name == "grid":
            # eq 3 in https://arxiv.org/pdf/2201.10489.pdf
            # code from https://github.com/gengchenmai/space2vec/blob/a29793336e6a1ebdb497289c286a0b4d5a83079f/spacegraph/spacegraph_codebase/SpatialRelationEncoder.py#L135
            # (N, D, frequency_num, 2): sin in slot 0 (dim 2i), cos in slot 1 (dim 2i+1)
            embeds = np.stack([np.sin(scaled[..., 0]), np.cos(scaled[..., 1])], axis=-1)
        else:
            embeds = self._sphere_embeds(raw, scaled)

        # Explicit width (instead of -1) so an empty batch reshapes cleanly.
        width = int(np.prod(embeds.shape[1:]))
        return torch.from_numpy(embeds.reshape(n, width)).to(device=device, dtype=dtype)

    def _sphere_embeds(self, raw, scaled):
        """Compute the Sphere2Vec features; returns (N, frequency_num, 2 * n_terms).

        ``raw`` and ``scaled`` are (N, D, frequency_num, 2) arrays of the raw
        and frequency-scaled coordinates. Their last axis holds two identical
        copies, so every term appears twice in the output; ``embedding_dim``
        counts these duplicates, so they are kept to preserve the output width.
        """
        # multi-scale angles, each (N, frequency_num, 2); lambda = lon, theta = lat
        lon, lat = scaled[:, 0], scaled[:, 1]
        lon_sin, lon_cos = np.sin(lon), np.cos(lon)
        lat_sin, lat_cos = np.sin(lat), np.cos(lat)

        if self.name == "spherec":
            # eq 4: sin(lat), cos(lat) cos(lon), cos(lat) sin(lon)
            terms = [lat_sin, lat_cos * lon_cos, lat_cos * lon_sin]
        elif self.name == "spherecplus":
            # eq 10: grid terms + spherec terms (sin(lat) shared)
            terms = [lat_sin, lat_cos, lon_sin, lon_cos,
                     lat_cos * lon_cos, lat_cos * lon_sin]
        else:
            # spherem / spheremplus: mix multi-scale terms with single-scale
            # terms of the raw (unscaled) coordinates.
            # code from https://github.com/gengchenmai/sphere2vec/blob/8e923bbceab6065cbb4f26398122a5a6f08e0135/main/SpatialRelationEncoder.py#L1753
            lon_single, lat_single = raw[:, 0], raw[:, 1]
            lon_single_sin, lon_single_cos = np.sin(lon_single), np.cos(lon_single)
            lat_single_cos = np.cos(lat_single)
            mixed = [lat_cos * lon_single_cos, lat_single_cos * lon_cos,
                     lat_cos * lon_single_sin, lat_single_cos * lon_sin]
            if self.name == "spherem":
                terms = [lat_sin] + mixed
            else:  # spheremplus
                terms = [lat_sin, lat_cos, lon_sin, lon_cos] + mixed
        return np.concatenate(terms, axis=-1)