"""
Based on
https://pytorch.org/tutorials/beginner/transformer_tutorial.html#define-the-model
"""
import math
import torch
from torch import nn


class SinusoidalEncoding(nn.Module):
    """
    Uses sine & cosine to represent position.
    """

    def __init__(self, model_dim, max_len):
        super().__init__()
        self.register_buffer('pe', self.build_position(model_dim, max_len))

    @staticmethod
    def build_position(model_dim, max_len, device=None):
        # Standard sinusoidal table: even columns hold sin(pos / 10000^(2i/d)),
        # odd columns hold cos(pos / 10000^(2i/d)). Assumes model_dim is even.
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, model_dim, 2) * (-math.log(10000.0) / model_dim))
        pe = torch.zeros(max_len, model_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        if device is not None:
            pe = pe.to(device=device)
        return pe

    def forward(self, x):
        max_index = int(x.max())
        if max_index >= self.pe.shape[0]:
            # Drop the old buffer before building the larger one, so the two
            # tables never coexist; the goal is to save memory if we are close
            # to the memory limit.
            device = self.pe.device
            model_dim = self.pe.shape[1]
            self.register_buffer('pe', None)
            # TODO: this may result in very poor performance
            # if the requested length grows one position at a time
            self.register_buffer('pe', self.build_position(model_dim, max_index + 1, device=device))
        return self.pe[x]

    def max_len(self):
        return self.pe.shape[0]


class AddSinusoidalEncoding(nn.Module):
    """
    Uses sine & cosine to represent position, and adds the position to the given tensor.
    Default behavior is batch_first.
    """

    def __init__(self, d_model=256, max_len=512):
        super().__init__()
        self.encoding = SinusoidalEncoding(d_model, max_len)

    def forward(self, x, scale=1.0):
        """
        Adds the positional encoding to the input tensor.
        The tensor is expected to be of shape (B, N, D) or (N, D).
        Properly masking the output tensor is up to the caller.
        """
        if len(x.shape) == 3:
            timing = self.encoding(torch.arange(x.shape[1], device=x.device))
            timing = timing.expand(x.shape[0], -1, -1)
        elif len(x.shape) == 2:
            timing = self.encoding(torch.arange(x.shape[0], device=x.device))
        else:
            raise ValueError(f'expected a 2D or 3D tensor, got shape {tuple(x.shape)}')
        return x + timing * scale


class ConcatSinusoidalEncoding(nn.Module):
    """
    Uses sine & cosine to represent position, and concatenates the position to the
    given tensor along the last dimension, returning a larger tensor.
    Default behavior is batch_first.
    """

    def __init__(self, d_model=256, max_len=512):
        super().__init__()
        self.encoding = SinusoidalEncoding(d_model, max_len)

    def forward(self, x):
        if len(x.shape) == 3:
            timing = self.encoding(torch.arange(x.shape[1], device=x.device))
            timing = timing.expand(x.shape[0], -1, -1)
        else:
            timing = self.encoding(torch.arange(x.shape[0], device=x.device))
        out = torch.cat((x, timing), dim=-1)
        return out
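

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module): the sizes and
    # variable names below are illustrative assumptions, not values taken from
    # any particular model.
    batch, seq_len, d_model = 4, 10, 256
    x = torch.randn(batch, seq_len, d_model)

    add_enc = AddSinusoidalEncoding(d_model=d_model, max_len=512)
    added = add_enc(x)             # same shape as x: (4, 10, 256)

    cat_enc = ConcatSinusoidalEncoding(d_model=d_model, max_len=512)
    concatenated = cat_enc(x)      # last dim doubles: (4, 10, 512)

    # The base encoder grows its buffer on demand when indexed past max_len.
    enc = SinusoidalEncoding(model_dim=8, max_len=4)
    _ = enc(torch.arange(16))      # forces a rebuild covering positions 0..15
    print(added.shape, concatenated.shape, enc.max_len())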