"""
Based on
https://pytorch.org/tutorials/beginner/transformer_tutorial.html#define-the-model
"""

import math

import torch
from torch import nn

class SinusoidalEncoding(nn.Module):
    """
    Uses sine & cosine to represent position
    """
    def __init__(self, model_dim, max_len):
        super().__init__()
        self.register_buffer('pe', self.build_position(model_dim, max_len))

    @staticmethod
    def build_position(model_dim, max_len, device=None):
        # Standard transformer sinusoid table: row i holds interleaved
        # sin/cos waves of geometrically decreasing frequency for position i.
        # Note: as in the PyTorch tutorial this is based on, model_dim is
        # assumed to be even; an odd model_dim would make the cosine
        # assignment below fail with a shape mismatch.
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, model_dim, 2) * (-math.log(10000.0) / model_dim))
        pe = torch.zeros(max_len, model_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        if device is not None:
            pe = pe.to(device=device)
        return pe

    def forward(self, x):
        max_pos = int(x.max())
        if max_pos >= self.pe.shape[0]:
            # Drop the reference first, before creating the new encoding;
            # the goal is to let the old table be freed if we are close to
            # the memory limit.
            device = self.pe.device
            model_dim = self.pe.shape[1]
            self.register_buffer('pe', None)
            # TODO: this may result in very poor performance for a model
            # whose sequence length increases one position at a time
            self.register_buffer('pe', self.build_position(model_dim, max_pos + 1, device=device))
        return self.pe[x]

    def max_len(self):
        return self.pe.shape[0]

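# A minimal usage sketch (illustrative, not part of the original module):
# the module acts as a position -> vector lookup and transparently rebuilds
# its table when indexed past max_len.
#
#   enc = SinusoidalEncoding(model_dim=8, max_len=16)
#   vecs = enc(torch.tensor([0, 3, 7]))   # shape: (3, 8)
#   enc(torch.tensor([100]))              # table grows; max_len() is now 101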

class AddSinusoidalEncoding(nn.Module):
    """
    Uses sine & cosine to represent position.  Adds the position to the given matrix

    Default behavior is batch_first
    """
    def __init__(self, d_model=256, max_len=512):
        super().__init__()
        self.encoding = SinusoidalEncoding(d_model, max_len)

    def forward(self, x, scale=1.0):
        """
        Adds the positional encoding to the input tensor, multiplied by
        `scale`.

        The tensor is expected to have shape (B, N, D), or (N, D) for a
        single unbatched sequence.  Properly masking the output tensor is
        up to the caller.
        """
        if x.dim() == 3:
            timing = self.encoding(torch.arange(x.shape[1], device=x.device))
            timing = timing.expand(x.shape[0], -1, -1)
        elif x.dim() == 2:
            timing = self.encoding(torch.arange(x.shape[0], device=x.device))
        else:
            raise ValueError(f"expected a 2D or 3D input, got shape {tuple(x.shape)}")
        return x + timing * scale
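
# Usage sketch (illustrative): `scale` lets the caller down-weight the
# positional signal relative to the token embeddings.
#
#   add_pos = AddSinusoidalEncoding(d_model=256, max_len=512)
#   y = add_pos(torch.zeros(4, 10, 256), scale=0.5)   # shape: (4, 10, 256)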


class ConcatSinusoidalEncoding(nn.Module):
    """
    Uses sine & cosine to represent position.  Concats the position and returns a larger object

    Default behavior is batch_first
    """
    def __init__(self, d_model=256, max_len=512):
        super().__init__()
        self.encoding = SinusoidalEncoding(d_model, max_len)

    def forward(self, x):
        """
        Returns the input with the positional encoding concatenated along
        the last dimension, i.e. with shape (..., D + d_model).
        """
        if x.dim() == 3:
            timing = self.encoding(torch.arange(x.shape[1], device=x.device))
            timing = timing.expand(x.shape[0], -1, -1)
        elif x.dim() == 2:
            timing = self.encoding(torch.arange(x.shape[0], device=x.device))
        else:
            raise ValueError(f"expected a 2D or 3D input, got shape {tuple(x.shape)}")

        return torch.cat((x, timing), dim=-1)
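

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original module):
    # exercises all three encodings on small inputs.
    torch.manual_seed(0)

    enc = SinusoidalEncoding(model_dim=8, max_len=4)
    assert enc(torch.tensor([0, 1, 2, 3])).shape == (4, 8)
    # Indexing past max_len triggers a rebuild of the cached table.
    assert enc(torch.tensor([9])).shape == (1, 8)
    assert enc.max_len() == 10

    add_pos = AddSinusoidalEncoding(d_model=16, max_len=32)
    x = torch.randn(2, 5, 16)
    assert add_pos(x).shape == (2, 5, 16)

    cat_pos = ConcatSinusoidalEncoding(d_model=16, max_len=32)
    assert cat_pos(x).shape == (2, 5, 32)
    print('ok')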