# English
# Shanci's picture
# Upload folder using huggingface_hub
# 26225c5 verified
import math

import torch
__all__ = ['fourier_position_encoder']
def fourier_position_encoder(pos, dim, f_min=1e-1, f_max=1e1):
    """Encode positions with interleaved cosine/sine features at
    log-spaced frequencies in [f_min, f_max].

    Heuristic: keeping ``f_min = 1 / f_max`` ensures that roughly 50%
    of the encoding dimensions are untouched and free to use. This is
    important when the positional encoding is added to learned feature
    embeddings. If the positional encoding uses too much of the encoding
    dimensions, it may be detrimental for the embeddings.

    The default `f_min` and `f_max` values are set so as to ensure
    a '~50% use of the encoding dimensions' and a '~1e-3 precision in
    the position encoding if pos is 1D'.

    :param pos: [N, M] or [N] Tensor
        Positions, expected to be in [-1, 1]
    :param dim: int
        Number of encoding dimensions, size of the encoding space. Note
        that increasing this is NOT the most direct way of improving
        spatial encoding precision or compactness. See `f_min` and
        `f_max` instead
    :param f_min: float
        Lower bound for the frequency range. Rules how much 'room' the
        positional encodings leave in the encoding space for additive
        embeddings
    :param f_max: float
        Upper bound for the frequency range. Rules how precise the
        encoding can be. Increase this if you need to capture finer
        spatial details
    :return: [N, dim] Tensor of positional encodings
    """
    assert pos.abs().max() <= 1, "Positions must be in [-1, 1]"
    assert 1 <= pos.dim() <= 2, "Positions must be a 1D or 2D tensor"

    # We preferably operate on 2D tensors
    if pos.dim() == 1:
        pos = pos.view(-1, 1)

    # Each of the M spatial dimensions receives D = dim // M encoding
    # channels. If dim is not a multiple of M, the remainder is
    # zero-padded at the end (see below)
    N, M = pos.shape
    D = dim // M

    # To avoid uncomfortable border effects with -1 and +1 coordinates
    # having the same (or very close) encodings, we convert [-1, 1]
    # coordinates to [-π/2, π/2] for safety
    pos = pos * torch.pi / 2

    # Compute frequencies on a logarithmic range from f_max down to
    # f_min. NB: torch.logspace interprets its start/end arguments as
    # BASE-10 exponents by default, so they must be log10 values — a
    # natural log here would silently shift the frequency range (e.g.
    # f_max=10 would become 10^ln(10) ≈ 200.9)
    device = pos.device
    w = torch.logspace(
        math.log10(f_max), math.log10(f_min), D, device=device)

    # Interleave cosine (even channels) and sine (odd channels)
    # encodings for each spatial dimension, then flatten to [N, M * D]
    pos_enc = pos.view(N, M, 1) * w.view(1, -1)
    pos_enc[:, :, ::2] = pos_enc[:, :, ::2].cos()
    pos_enc[:, :, 1::2] = pos_enc[:, :, 1::2].sin()
    pos_enc = pos_enc.view(N, -1)

    # In case dim is not a multiple of M, pad the missing trailing
    # dimensions with zeros
    if pos_enc.shape[1] < dim:
        zeros = torch.zeros(N, dim - pos_enc.shape[1], device=device)
        pos_enc = torch.hstack((pos_enc, zeros))

    return pos_enc