import math

import torch


__all__ = ['fourier_position_encoder']


def fourier_position_encoder(pos, dim, f_min=1e-1, f_max=1e1):
    """
    Fourier positional encoding for positions in [-1, 1], built from sine
    and cosine features at frequencies log-spaced between `f_min` and
    `f_max`.

    Heuristic: keeping `f_min = 1 / f_max` ensures that roughly 50% of the
    encoding dimensions are untouched and free to use. This is important
    when the positional encoding is added to learned feature embeddings:
    if the positional encoding uses too many of the encoding dimensions,
    it may be detrimental to the embeddings.

    The default `f_min` and `f_max` values are chosen to provide roughly
    50% use of the encoding dimensions and roughly 1e-3 precision in the
    position encoding when `pos` is 1D (see the usage sketch at the end
    of this module).

    :param pos: [N, M] Tensor
        Positions, expected to be in [-1, 1]
    :param dim: int
        Number of encoding dimensions, i.e. the size of the encoding
        space. Note that increasing this is NOT the most direct way of
        improving spatial encoding precision or compactness. See `f_min`
        and `f_max` instead
    :param f_min: float
        Lower bound for the frequency range. Rules how much 'room' the
        positional encodings leave in the encoding space for additive
        embeddings
    :param f_max: float
        Upper bound for the frequency range. Rules how precise the
        encoding can be. Increase this if you need to capture finer
        spatial details
    :return: [N, dim] Tensor
        Fourier positional encodings
    """
    assert pos.abs().max() <= 1, "Positions must be in [-1, 1]"
    assert 1 <= pos.dim() <= 2, "Positions must be a 1D or 2D tensor"

    # Operate on 2D tensors: treat a 1D input as N positions of dimension 1
    if pos.dim() == 1:
        pos = pos.view(-1, 1)

    # Each of the M spatial dimensions gets D = dim // M encoding channels
    N, M = pos.shape
    D = dim // M

    # Map positions from [-1, 1] to [-pi/2, pi/2]
    pos = pos * torch.pi / 2

    # D frequencies, log-spaced from f_max down to f_min
    device = pos.device
    w = torch.logspace(
        math.log10(f_max), math.log10(f_min), D, device=device)

    # Interleave cosine (even channels) and sine (odd channels) of pos * w
    pos_enc = pos.view(N, M, 1) * w.view(1, -1)
    pos_enc[:, :, ::2] = pos_enc[:, :, ::2].cos()
    pos_enc[:, :, 1::2] = pos_enc[:, :, 1::2].sin()
    pos_enc = pos_enc.view(N, -1)

    # If M does not divide dim, pad the remaining channels with zeros
    if pos_enc.shape[1] < dim:
        zeros = torch.zeros(N, dim - pos_enc.shape[1], device=device)
        pos_enc = torch.hstack((pos_enc, zeros))

    return pos_enc
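

# Minimal usage sketch (not part of the original API): it illustrates the
# shape contract and the docstring heuristic that, with the default
# f_min = 1 / f_max, the channels driven by the lowest frequencies stay
# nearly constant across positions, leaving room for additive learned
# embeddings. The sample sizes below are arbitrary illustrative choices.
if __name__ == '__main__':
    # 128 random 3D positions in [-1, 1], encoded into 64 dimensions
    pos = torch.rand(128, 3) * 2 - 1
    enc = fourier_position_encoder(pos, 64)
    print(enc.shape)  # torch.Size([128, 64])

    # Per-channel standard deviation across positions: low-frequency channels
    # vary far less than high-frequency ones
    print(enc.std(dim=0))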