|
|
import torch |
|
|
from torch import nn |
|
|
import numpy as np |
|
|
import math |
|
|
|
|
|
from .common import _cal_freq_list |
|
|
|
|
|
""" |
|
|
Grid, SphereC, SphereCPlus, SphereM, SphereMPlus location encoders |
|
|
""" |
|
|
class GridAndSphere(nn.Module):
    """Sinusoidal location encoder with Grid and Sphere* variants.

    Given a batch of coordinates (column 0 is treated as longitude / deltaX,
    column 1 as latitude / deltaY), scale every coordinate by a bank of
    frequencies and combine sines and cosines of the scaled values according
    to the selected variant (``name``):

    * ``"grid"``        -> ``4 * frequency_num`` features
    * ``"spherec"``     -> ``6 * frequency_num`` features
    * ``"spherecplus"`` -> ``12 * frequency_num`` features
    * ``"spherem"``     -> ``10 * frequency_num`` features
    * ``"spheremplus"`` -> ``16 * frequency_num`` features

    The encoding is computed with NumPy on the CPU, so no gradient flows
    from the output back to the input coordinates.
    """

    # Class name installed on ``GridAndSphere`` by ``__init__`` per variant.
    _VARIANT_CLASS_NAMES = {
        "grid": "Grid",
        "spherec": "SphereC",
        "spherecplus": "SphereCPlus",
        "spherem": "SphereM",
        "spheremplus": "SphereMPlus",
    }

    # Number of output features generated per frequency by each variant.
    _FEATURES_PER_FREQUENCY = {
        "grid": 4,
        "spherec": 6,
        "spherecplus": 12,
        "spherem": 10,
        "spheremplus": 16,
    }

    def __init__(self, coord_dim=2, frequency_num=16,
                 max_radius=0.01, min_radius=0.00001,
                 freq_init="geometric", name="grid"):
        """
        Args:
            coord_dim: the dimension of space, 2D, 3D, or other (stored only;
                not otherwise used by this class)
            frequency_num: the number of different sinusoidal with different
                frequencies/wavelengths
            max_radius: the largest context radius this model can handle
            min_radius: forwarded to ``_cal_freq_list`` together with
                ``max_radius``; presumably the smallest context radius
            freq_init: frequency initialisation scheme, forwarded to
                ``_cal_freq_list``
            name: encoder variant; one of "grid", "spherec", "spherecplus",
                "spherem", "spheremplus"

        Raises:
            ValueError: if ``name`` is not a supported variant (raised by
                ``cal_embedding_dim``).
        """
        super(GridAndSphere, self).__init__()

        # NOTE: this renames the *class object*, so it affects every instance
        # (existing and future), not just the one being constructed. Kept
        # because callers may rely on the reported class name.
        variant_class_name = self._VARIANT_CLASS_NAMES.get(name)
        if variant_class_name is not None:
            GridAndSphere.__qualname__ = variant_class_name
            GridAndSphere.__name__ = variant_class_name

        self.coord_dim = coord_dim
        self.frequency_num = frequency_num
        self.freq_init = freq_init
        self.max_radius = max_radius
        self.min_radius = min_radius

        self.cal_freq_list()
        self.cal_freq_mat()
        self.name = name
        self.embedding_dim = self.cal_embedding_dim()

    def cal_elementwise_angle(self, coord, cur_freq):
        """
        Args:
            coord: the deltaX or deltaY
            cur_freq: the frequency (index in ``range(frequency_num)``)

        Returns:
            ``coord / max_radius ** (cur_freq / (frequency_num - 1))``
        """
        # Clamp the denominator so a single-frequency encoder does not divide
        # by zero; its only index (0) then yields an exponent of 0.
        denominator = max(self.frequency_num - 1, 1)
        return coord / (np.power(self.max_radius, cur_freq * 1.0 / denominator))

    def cal_coord_embed(self, coords_tuple):
        """Return a flat list ``[sin, cos, sin, cos, ...]`` of scalar encodings.

        For every scalar in ``coords_tuple`` and every frequency index, the
        sine and cosine of ``cal_elementwise_angle`` are appended in order.
        """
        embed = []
        for coord in coords_tuple:
            for cur_freq in range(self.frequency_num):
                angle = self.cal_elementwise_angle(coord, cur_freq)
                embed.append(math.sin(angle))
                embed.append(math.cos(angle))
        return embed

    def cal_embedding_dim(self):
        """Return the output feature count for the configured variant.

        Raises:
            ValueError: if ``self.name`` is not a supported variant.
        """
        try:
            features_per_frequency = self._FEATURES_PER_FREQUENCY[self.name]
        except KeyError:
            raise ValueError(
                "Unknown location encoder name {!r}; expected one of {}".format(
                    self.name, sorted(self._FEATURES_PER_FREQUENCY)))
        return int(features_per_frequency * self.frequency_num)

    def cal_freq_list(self):
        """Populate ``self.freq_list`` via the project helper ``_cal_freq_list``."""
        # Presumably a 1-D array with ``frequency_num`` entries - confirm in
        # ``.common._cal_freq_list``.
        self.freq_list = _cal_freq_list(self.freq_init, self.frequency_num,
                                        self.max_radius, self.min_radius)

    def cal_freq_mat(self):
        """Build ``self.freq_mat`` by duplicating each frequency into 2 columns.

        One column feeds the sine term and one the cosine term; for a 1-D
        ``freq_list`` of length F the result has shape (F, 2).
        """
        freq_mat = np.expand_dims(self.freq_list, axis=1)
        self.freq_mat = np.repeat(freq_mat, 2, axis=1)

    def forward(self, coords):
        """Encode a batch of coordinates.

        Args:
            coords: tensor of shape (N, D); column 0 is used as longitude and
                column 1 as latitude by the sphere variants. The tensor may
                live on any device and may require grad (it is detached).

        Returns:
            Tensor with N rows on the device and with the dtype of ``coords``.
            For D == 2 each row has ``self.embedding_dim`` features.

        Raises:
            ValueError: if ``self.name`` is not a supported variant.
        """
        device = coords.device
        dtype = coords.dtype
        N = coords.size(0)

        # detach() is required: NumPy conversion fails on tensors that
        # require grad, and the NumPy path is non-differentiable anyway.
        coords_np = coords.detach().cpu().numpy()

        # (N, 1, D, 1, 1): broadcast against freq_mat (F, 2) instead of
        # materialising the repeated (N, 1, D, F, 2) array.
        coords_mat = coords_np[:, None, :, None, None]

        # (N, 1, D, F, 2): every coordinate scaled by every frequency.
        spr_embeds = coords_mat * self.freq_mat

        if self.name == "grid":
            # Even slots of the last axis hold sines, odd slots cosines.
            spr_embeds[..., 0::2] = np.sin(spr_embeds[..., 0::2])
            spr_embeds[..., 1::2] = np.cos(spr_embeds[..., 1::2])
        elif self.name in ("spherec", "spherecplus", "spherem", "spheremplus"):
            # Frequency-scaled longitude / latitude, each (N, 1, 1, F, 2).
            lon = spr_embeds[:, :, 0:1, :, :]
            lat = spr_embeds[:, :, 1:2, :, :]

            lon_sin = np.sin(lon)
            lon_cos = np.cos(lon)
            lat_sin = np.sin(lat)
            lat_cos = np.cos(lat)

            if self.name == "spherec":
                parts = [lat_sin, lat_cos * lon_cos, lat_cos * lon_sin]
            elif self.name == "spherecplus":
                parts = [lat_sin, lat_cos, lon_sin, lon_cos,
                         lat_cos * lon_cos, lat_cos * lon_sin]
            else:
                # Multi-scale variants, following
                # https://github.com/gengchenmai/sphere2vec/blob/8e923bbceab6065cbb4f26398122a5a6f08e0135/main/SpatialRelationEncoder.py#L1753
                # Unscaled longitude / latitude, (N, 1, 1, 1, 1); they
                # broadcast against the scaled terms below.
                lon_single = coords_mat[:, :, 0:1, :, :]
                lat_single = coords_mat[:, :, 1:2, :, :]

                lon_single_sin = np.sin(lon_single)
                lon_single_cos = np.cos(lon_single)
                lat_single_cos = np.cos(lat_single)

                mixed = [lat_cos * lon_single_cos, lat_single_cos * lon_cos,
                         lat_cos * lon_single_sin, lat_single_cos * lon_sin]
                if self.name == "spherem":
                    parts = [lat_sin] + mixed
                else:
                    parts = [lat_sin, lat_cos, lon_sin, lon_cos] + mixed

            spr_embeds = np.concatenate(parts, axis=-1)
        else:
            raise ValueError(
                "Unknown location encoder name {!r}; expected one of {}".format(
                    self.name, sorted(self._FEATURES_PER_FREQUENCY)))

        # Use an explicit width instead of -1 so an empty batch (N == 0)
        # reshapes cleanly.
        flat_width = int(np.prod(spr_embeds.shape[1:]))
        flat = spr_embeds.reshape(N, flat_width)
        return torch.from_numpy(flat).to(dtype).to(device)
|
|
|