|
|
import torch |
|
|
from torch import nn |
|
|
import numpy as np |
|
|
import math |
|
|
|
|
|
from .common import _cal_freq_list |
|
|
|
|
|
""" |
|
|
Theory based location encoder |
|
|
""" |
|
|
class Theory(nn.Module):
    """Theory-based location encoder for 2-D coordinates.

    Given a list of (deltaX, deltaY), encode them using the position encoding
    function: each pair is projected onto three unit vectors spaced 120
    degrees apart, every projection is scaled by each entry of a frequency
    list, and the scaled values are passed through ``sin`` and ``cos``.
    """

    def __init__(self, coord_dim=2, frequency_num=16,
                 max_radius=10000, min_radius=1000, freq_init="geometric"):
        """
        Args:
            coord_dim: the dimension of space, 2D, 3D, or other. It is only
                stored; the projection vectors below are hard-coded 2-D, so
                ``forward`` accepts 2-D coordinates only.
            frequency_num: the number of different sinusoidal with different
                frequencies/wavelengths. Each frequency contributes six
                output values per point (sin and cos for three projections).
            max_radius: the largest context radius this model can handle
                (forwarded to ``_cal_freq_list``).
            min_radius: forwarded to ``_cal_freq_list``; presumably the
                smallest context radius -- confirm against that helper.
            freq_init: name of the frequency initialisation scheme, forwarded
                to ``_cal_freq_list``.
        """
        super().__init__()
        self.frequency_num = frequency_num
        self.coord_dim = coord_dim
        self.max_radius = max_radius
        self.min_radius = min_radius
        self.freq_init = freq_init

        # freq_list must exist before freq_mat can be derived from it.
        self.cal_freq_list()
        self.cal_freq_mat()

        # Three unit vectors at 0, 120 and 240 degrees.
        self.unit_vec1 = np.asarray([1.0, 0.0])
        self.unit_vec2 = np.asarray([-1.0 / 2.0, math.sqrt(3) / 2.0])
        self.unit_vec3 = np.asarray([-1.0 / 2.0, -math.sqrt(3) / 2.0])

        self.embedding_dim = self.cal_embedding_dim()

    def cal_freq_list(self):
        """Store the frequency list produced by ``_cal_freq_list``."""
        self.freq_list = _cal_freq_list(
            self.freq_init, self.frequency_num, self.max_radius, self.min_radius
        )

    def cal_freq_mat(self):
        """Tile ``freq_list`` into ``freq_mat``, one column per sin/cos slot.

        Assumes ``freq_list`` is a 1-D array of length ``frequency_num`` so
        that ``freq_mat`` has shape (frequency_num, 6) -- confirm against
        ``_cal_freq_list``.
        """
        freq_mat = np.expand_dims(self.freq_list, axis=1)
        self.freq_mat = np.repeat(freq_mat, 6, axis=1)

    def cal_embedding_dim(self):
        """Return the number of output values per encoded point."""
        return int(2 * 3 * self.frequency_num)

    def forward(self, coords):
        """Encode coordinates as sin/cos features.

        The computation runs in NumPy on the CPU, so the result carries no
        autograd history with respect to ``coords``.

        Args:
            coords: tensor of shape (N, 2) or (N, num_context_pt, 2); more
                leading axes are accepted as long as the last axis has
                size 2.

        Returns:
            Tensor of shape (N, M) with the dtype and device of ``coords``,
            where M is ``6 * frequency_num`` times the product of any axes
            between the first and the last. For every point and frequency
            the six values are ordered (sin, cos) for unit vector 1, then 2,
            then 3.

        Raises:
            ValueError: if ``coords`` has fewer than two axes or its last
                axis is not of size 2.
        """
        device = coords.device
        dtype = coords.dtype
        N = coords.size(0)

        coords_mat = np.asarray(coords.cpu())
        if coords_mat.ndim < 2 or coords_mat.shape[-1] != 2:
            raise ValueError(
                "coords must have shape (N, ..., 2), got "
                f"{tuple(coords_mat.shape)}"
            )

        # Dot product of every (deltaX, deltaY) with each unit vector.
        unit_vecs = (self.unit_vec1, self.unit_vec2, self.unit_vec3)
        projections = [np.matmul(coords_mat, vec) for vec in unit_vecs]

        # Each projection appears twice: the even slot feeds sin, the odd
        # slot feeds cos. Resulting shape: (..., 6).
        angle_mat = np.stack(
            [proj for proj in projections for _ in range(2)], axis=-1
        )

        # Broadcast (..., 1, 6) against freq_mat (frequency_num, 6).
        angle_mat = angle_mat[..., np.newaxis, :] * self.freq_mat

        # Apply sin/cos on the size-6 axis *before* flattening, so the
        # sin/cos alternation cannot be shifted by how the leading axes are
        # later merged (this matters for (N, 2) input with an odd
        # frequency_num).
        spr_embeds = np.empty_like(angle_mat)
        spr_embeds[..., 0::2] = np.sin(angle_mat[..., 0::2])
        spr_embeds[..., 1::2] = np.cos(angle_mat[..., 1::2])

        # An explicit width (instead of -1) keeps the reshape valid for an
        # empty batch.
        flat_dim = int(np.prod(spr_embeds.shape[1:]))
        flat = spr_embeds.reshape(N, flat_dim)
        return torch.from_numpy(flat).to(dtype).to(device)
|
|
|
|
|
|
|
|
|
|
|
|