# Uploaded by ML4RS-Anonymous ("Upload all files", commit eb1aec4, verified).
import torch
from torch import nn
import numpy as np
import math
from .common import _cal_freq_list
"""
Grid, SphereC, SphereCPlus, SphereM, SphereMPlus location encoders
"""
class GridAndSphere(nn.Module):
    """
    Sinusoidal location encoder implementing the Grid, SphereC, SphereCPlus,
    SphereM and SphereMPlus variants; the variant is selected with ``name``.

    Given a batch of 2D coordinates (column 0 is treated as longitude, column 1
    as latitude by the sphere variants), every coordinate is multiplied by each
    entry of a frequency list and passed through sin/cos combinations that
    depend on the variant. The encoding is computed with NumPy, so no gradient
    flows back to the input coordinates.
    """

    # variant name -> class name the encoder presents itself as
    _DISPLAY_NAMES = {
        "grid": "Grid",
        "spherec": "SphereC",
        "spherecplus": "SphereCPlus",
        "spherem": "SphereM",
        "spheremplus": "SphereMPlus",
    }

    # variant name -> number of output features produced per frequency
    _FEATURES_PER_FREQ = {
        "grid": 4,
        "spherec": 6,  # xyz instead of lon lat
        "spherecplus": 12,
        "spherem": 10,
        "spheremplus": 16,
    }

    def __init__(self, coord_dim=2, frequency_num=16,
                 max_radius=0.01, min_radius=0.00001,
                 freq_init="geometric", name="grid"):
        """
        Args:
            coord_dim: the dimension of space, 2D, 3D, or other
            frequency_num: the number of different sinusoidal with different frequencies/wavelengths
            max_radius: the largest context radius this model can handle
            min_radius: the smallest context radius, forwarded to the frequency-list helper
            freq_init: frequency initialisation scheme, forwarded to the frequency-list helper
            name: encoder variant, one of "grid", "spherec", "spherecplus",
                "spherem", "spheremplus"

        Raises:
            ValueError: if ``name`` is not one of the supported variants.
        """
        super().__init__()

        if name not in self._DISPLAY_NAMES:
            raise ValueError(
                "unknown location encoder name %r; expected one of %s"
                % (name, sorted(self._DISPLAY_NAMES))
            )

        # change name attribute to emulate the subclass
        # NOTE: this renames the class object itself, so the new name is shared
        # by every instance, including ones created earlier with another variant.
        GridAndSphere.__qualname__ = self._DISPLAY_NAMES[name]
        GridAndSphere.__name__ = self._DISPLAY_NAMES[name]

        self.coord_dim = coord_dim
        self.frequency_num = frequency_num
        self.freq_init = freq_init
        self.max_radius = max_radius
        self.min_radius = min_radius

        # the frequence we use for each block, alpha in ICLR paper
        self.cal_freq_list()
        self.cal_freq_mat()

        self.name = name
        self.embedding_dim = self.cal_embedding_dim()

    def cal_elementwise_angle(self, coord, cur_freq):
        """
        Args:
            coord: the deltaX or deltaY
            cur_freq: the frequency index

        Returns:
            coord / max_radius ** (cur_freq / (frequency_num - 1))
        """
        # With a single frequency the original denominator is zero; clamp it so
        # the exponent is 0 and the coordinate is returned unscaled.
        denom = max(self.frequency_num - 1, 1)
        return coord / (np.power(self.max_radius, cur_freq * 1.0 / denom))

    def cal_coord_embed(self, coords_tuple):
        """
        Encode every coordinate in ``coords_tuple`` with a sin/cos pair per
        frequency index.

        Returns:
            flat list of length len(coords_tuple) * frequency_num * 2
        """
        embed = []
        for coord in coords_tuple:
            for cur_freq in range(self.frequency_num):
                angle = self.cal_elementwise_angle(coord, cur_freq)
                embed.append(math.sin(angle))
                embed.append(math.cos(angle))
        # embed: shape (input_embed_dim)
        return embed

    def cal_embedding_dim(self):
        """
        Compute the dimension of the encoded spatial relation embedding.

        Raises:
            ValueError: if ``self.name`` is not a supported variant.
        """
        try:
            per_freq = self._FEATURES_PER_FREQ[self.name]
        except KeyError:
            raise ValueError(
                "unknown location encoder name %r; expected one of %s"
                % (self.name, sorted(self._FEATURES_PER_FREQ))
            ) from None
        return int(per_freq * self.frequency_num)

    def cal_freq_list(self):
        """Populate ``self.freq_list`` through the shared ``_cal_freq_list`` helper."""
        self.freq_list = _cal_freq_list(self.freq_init, self.frequency_num, self.max_radius, self.min_radius)

    def cal_freq_mat(self):
        """Build ``self.freq_mat`` by duplicating the frequency list into two columns."""
        # freq_mat shape: (frequency_num, 1)
        freq_mat = np.expand_dims(self.freq_list, axis=1)
        # self.freq_mat shape: (frequency_num, 2)
        self.freq_mat = np.repeat(freq_mat, 2, axis=1)

    def forward(self, coords):
        """
        Encode a batch of coordinates.

        Args:
            coords: tensor of shape (batch_size, 2)

        Returns:
            tensor of shape (batch_size, features_per_freq * frequency_num) with
            the dtype and device of ``coords``; features_per_freq is 4, 6, 12,
            10 or 16 for grid, spherec, spherecplus, spherem, spheremplus.

        Raises:
            ValueError: if ``self.name`` is not a supported variant.
        """
        if self.name not in self._FEATURES_PER_FREQ:
            raise ValueError(
                "unknown location encoder name %r; expected one of %s"
                % (self.name, sorted(self._FEATURES_PER_FREQ))
            )

        device = coords.device
        dtype = coords.dtype
        batch_size = coords.size(0)

        # coords_mat: shape (batch_size, 1, 2, 1, 1); the singleton axes are the
        # (unused) context-point axis and the axes broadcast against freq_mat
        coords_mat = np.asarray(coords.cpu())[:, None, :, None, None]
        # spr_embeds: shape (batch_size, 1, 2, frequency_num, 2)
        spr_embeds = coords_mat * self.freq_mat

        if self.name == "grid":
            # eq 3 in https://arxiv.org/pdf/2201.10489.pdf
            # code from https://github.com/gengchenmai/space2vec/blob/a29793336e6a1ebdb497289c286a0b4d5a83079f/spacegraph/spacegraph_codebase/SpatialRelationEncoder.py#L135
            spr_embeds[..., 0::2] = np.sin(spr_embeds[..., 0::2])  # dim 2i
            spr_embeds[..., 1::2] = np.cos(spr_embeds[..., 1::2])  # dim 2i+1
            encoded = spr_embeds
        else:
            # lambda: longitude, theta: latitude
            # lon, lat: shape (batch_size, 1, 1, frequency_num, 2)
            lon = spr_embeds[:, :, 0:1]
            lat = spr_embeds[:, :, 1:2]
            lon_sin, lon_cos = np.sin(lon), np.cos(lon)
            lat_sin, lat_cos = np.sin(lat), np.cos(lat)

            if self.name == "spherec":
                # eq 4 in https://arxiv.org/pdf/2201.10489.pdf
                # sin lat, cos_lat cos_lon, cos_lat sin_lon
                parts = [lat_sin, lat_cos * lon_cos, lat_cos * lon_sin]
            elif self.name == "spherecplus":
                # eq 10 in https://arxiv.org/pdf/2201.10489.pdf (basically grid + spherec)
                parts = [lat_sin, lat_cos, lon_sin, lon_cos,
                         lat_cos * lon_cos, lat_cos * lon_sin]
            else:
                # spherem / spheremplus
                # code from https://github.com/gengchenmai/sphere2vec/blob/8e923bbceab6065cbb4f26398122a5a6f08e0135/main/SpatialRelationEncoder.py#L1753
                # unscaled lon/lat: shape (batch_size, 1, 1, 1, 1), broadcast
                # against the frequency-scaled terms below
                lon_single = coords_mat[:, :, 0:1]
                lat_single = coords_mat[:, :, 1:2]
                lon_single_sin = np.sin(lon_single)
                lon_single_cos = np.cos(lon_single)
                lat_single_cos = np.cos(lat_single)

                mixed = [lat_cos * lon_single_cos, lat_single_cos * lon_cos,
                         lat_cos * lon_single_sin, lat_single_cos * lon_sin]
                if self.name == "spherem":
                    parts = [lat_sin] + mixed
                else:
                    parts = [lat_sin, lat_cos, lon_sin, lon_cos] + mixed

            # encoded: shape (batch_size, 1, 1, frequency_num, 2 * len(parts))
            encoded = np.concatenate(parts, axis=-1)

        return torch.from_numpy(encoded.reshape(batch_size, -1)).to(dtype).to(device)