import math

import torch

__all__ = ['fourier_position_encoder']


def fourier_position_encoder(pos, dim, f_min=1e-1, f_max=1e1):
| | """ |
| | Heuristic: keeping ```f_min = 1 / f_max``` ensures that roughly 50% |
| | of the encoding dimensions are untouched and free to use. This is |
| | important when the positional encoding is added to learned feature |
| | embeddings. If the positional encoding uses too much of the encoding |
| | dimensions, it may be detrimental for the embeddings. |
| | |
| | The default `f_min` and `f_max` values are set so as to ensure |
| | a '~50% use of the encoding dimensions' and a '~1e-3 precision in |
| | the position encoding if pos is 1D'. |
| | |
| | :param pos: [M, M] Tensor |
| | Positions are expected to be in [-1, 1] |
| | :param dim: int |
| | Number of encoding dimensions, size of the encoding space. Note |
| | that increasing this is NOT the most direct way of improving |
| | spatial encoding precision or compactness. See `f_min` and |
| | `f_max` instead |
| | :param f_min: float |
| | Lower bound for the frequency range. Rules how much 'room' the |
| | positional encodings leave in the encoding space for additive |
| | embeddings |
| | :param f_max: float |
| | Upper bound for the frequency range. Rules how precise the |
| | encoding can be. Increase this if you need to capture finer |
| | spatial details |
| | :return: |
| | """ |
    assert pos.abs().max() <= 1, "Positions must be in [-1, 1]"
    assert 1 <= pos.dim() <= 2, "Positions must be a 1D or 2D tensor"

    # Treat 1D positions as a single spatial dimension
    if pos.dim() == 1:
        pos = pos.view(-1, 1)

    # N points, M spatial dimensions, D encoding dimensions allotted
    # to each spatial dimension
    N, M = pos.shape
    D = dim // M

    # Map positions from [-1, 1] to [-pi/2, pi/2]
    pos = pos * torch.pi / 2

    # D frequencies logarithmically spaced from f_max down to f_min.
    # NB: torch.logspace expects base-10 exponents, hence log10 bounds
    device = pos.device
    w = torch.logspace(
        math.log10(f_max), math.log10(f_min), D, device=device)

    # Outer product of positions and frequencies, then interleaved
    # cosine / sine on even / odd channels
    pos_enc = pos.view(N, M, 1) * w.view(1, -1)
    pos_enc[:, :, ::2] = pos_enc[:, :, ::2].cos()
    pos_enc[:, :, 1::2] = pos_enc[:, :, 1::2].sin()
    pos_enc = pos_enc.view(N, -1)

    # Zero-pad to the requested dimension when dim is not a multiple
    # of M, so the output is always [N, dim]
    if pos_enc.shape[1] < dim:
        zeros = torch.zeros(N, dim - pos_enc.shape[1], device=device)
        pos_enc = torch.hstack((pos_enc, zeros))

    return pos_enc
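

# Minimal usage sketch (illustrative values, not prescribed defaults):
# encode 100 random 2D positions into a 32-dimensional space. With the
# default f_min = 1 / f_max, roughly half of the output channels stay
# near cos(0)=1 / sin(0)=0, leaving room for additive embeddings
if __name__ == '__main__':
    pos = torch.rand(100, 2) * 2 - 1   # positions in [-1, 1]
    enc = fourier_position_encoder(pos, dim=32)
    print(enc.shape)  # torch.Size([100, 32])

    # 1D positions are also supported
    enc_1d = fourier_position_encoder(torch.linspace(-1, 1, 5), dim=16)
    print(enc_1d.shape)  # torch.Size([5, 16])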