English
File size: 2,779 Bytes
26225c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import math

import torch


# Names exported by `from <module> import *`
__all__ = ["fourier_position_encoder"]


def fourier_position_encoder(pos, dim, f_min=1e-1, f_max=1e1):
    """
    Encode positions in [-1, 1] as interleaved cosine / sine features
    computed over a logarithmic range of frequencies.

    Each of the M input coordinates is encoded independently on
    `D = dim // M` channels: even channels hold a cosine, odd channels
    hold a sine. The per-coordinate encodings are concatenated along
    the feature dimension and, if `M * D < dim`, the trailing channels
    are filled with zeros.

    Heuristic: keeping `f_min = 1 / f_max` ensures that roughly 50%
    of the encoding dimensions are untouched and free to use. This is
    important when the positional encoding is added to learned feature
    embeddings. If the positional encoding uses too much of the encoding
    dimensions, it may be detrimental for the embeddings.

    The default `f_min` and `f_max` values are set so as to ensure
    a '~50% use of the encoding dimensions' and a '~1e-3 precision in
    the position encoding if pos is 1D'.

    NOTE: the frequencies are generated with `torch.logspace`, whose
    default base is 10, fed with the *natural* logarithms of the
    bounds. The frequencies actually applied hence span
    `[10 ** ln(f_min), 10 ** ln(f_max)]` (roughly [5e-3, 2e2] with the
    default values) rather than `[f_min, f_max]`. This is deliberately
    left as is, so that the produced encodings remain numerically
    unchanged.

    :param pos: [N] or [N, M] Tensor
        Positions are expected to be in [-1, 1]
    :param dim: int
        Number of encoding dimensions, size of the encoding space. Note
        that increasing this is NOT the most direct way of improving
        spatial encoding precision or compactness. See `f_min` and
        `f_max` instead
    :param f_min: float
        Lower bound for the frequency range. Rules how much 'room' the
        positional encodings leave in the encoding space for additive
        embeddings. Must be strictly positive
    :param f_max: float
        Upper bound for the frequency range. Rules how precise the
        encoding can be. Increase this if you need to capture finer
        spatial details. Must be strictly positive
    :return: [N, dim] Tensor
        Positional encodings, on the same device as `pos`
    """
    # `max()` is undefined on empty tensors, which trivially satisfy
    # the range constraint
    in_range = pos.numel() == 0 or bool(pos.abs().max() <= 1)
    assert in_range, "Positions must be in [-1, 1]"
    assert 1 <= pos.dim() <= 2, "Positions must be a 1D or 2D tensor"
    assert f_min > 0 and f_max > 0, \
        "Frequency bounds must be strictly positive"

    # We preferably operate 2D tensors
    if pos.dim() == 1:
        pos = pos.unsqueeze(1)

    # Each of the M coordinates gets D encoding channels. If M does not
    # divide dim, the leftover channels are zero-padded at the end
    N, M = pos.shape
    D = dim // M

    # To avoid uncomfortable border effects with -1 and +1 coordinates
    # having the same (or very close) encodings, we convert [-1, 1]
    # coordinates to [-π/2, π/2] for safety
    pos = pos * torch.pi / 2

    # Compute frequencies on a logarithmic range, from highest to
    # lowest. The bounds are passed as plain python floats, which
    # `torch.logspace` accepts regardless of the torch version, and
    # which avoids allocating throwaway tensors on the device
    device = pos.device
    w = torch.logspace(
        math.log(float(f_max)), math.log(float(f_min)), D, device=device)

    # Compute sine and cosine encodings. This is done out-of-place:
    # overwriting slices of the phase tensor in-place would invalidate
    # the tensors saved for the backward pass of `cos` and `sin`, and
    # break backpropagation whenever `pos` requires gradients
    phase = pos.unsqueeze(-1) * w
    is_cos = torch.arange(D, device=device) % 2 == 0
    pos_enc = torch.where(is_cos, phase.cos(), phase.sin())

    # Explicit target shape: `-1` cannot be inferred when N is 0
    pos_enc = pos_enc.reshape(N, M * D)

    # In case dim is not a multiple of M, we pad missing dimensions
    # with zeros of matching dtype and device
    if pos_enc.shape[1] < dim:
        zeros = pos_enc.new_zeros(N, dim - pos_enc.shape[1])
        pos_enc = torch.hstack((pos_enc, zeros))

    return pos_enc