"""
A collection of positional encoding modules.
"""
import torch
import math


class LearnedPosEncoding(torch.nn.Module):
    """
    Basic learned positional encoding.
    """

    def __init__(self, hidden_dim, context_window):
        super().__init__()
        self.pe = torch.nn.Embedding(
            num_embeddings=context_window, embedding_dim=hidden_dim
        )

    def forward(self, x):
        """
        Takes the input tensor and returns it positionally encoded.
        Args:
            x: torch.tensor(B, S, H)
        Returns:
            x: torch.tensor(B, S, H)
        """
        if x.dim() >= 3:
            # Batched input (B, S, H): broadcast the (1, S, H) encoding over the batch.
            return x + self.pe(torch.arange(x.size(1), device=x.device)).unsqueeze(0)
        # Unbatched input (S, H): positions are indexed along the first dimension.
        return x + self.pe(torch.arange(x.size(0), device=x.device))


class IdentityEncoding(torch.nn.Module):
    """
    In case RoPE is used, there is no need for an initial positional encoding.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """
        Returns the input tensor as is.
        """
        return x


class SinCosPosEncoding(torch.nn.Module):
    """SinCos encoding taken from:
    https://github.com/pytorch/examples/blob/main/word_language_model/model.py#L65
    As used in the Vaswani et al. paper..."""

    def __init__(self, hidden_dim, context_window):
        """Set up the pe buffer etc."""
        super().__init__()
        pe = torch.zeros(context_window, hidden_dim)
        position = torch.arange(0, context_window, dtype=torch.float).unsqueeze(1)
        # Frequencies follow a geometric progression: 1 / 10000^(2i / hidden_dim).
        div_term = torch.exp(
            torch.arange(0, hidden_dim, 2).float() * (-math.log(10000.0) / hidden_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even feature indices
        pe[:, 1::2] = torch.cos(position * div_term)  # odd feature indices
        pe = pe.unsqueeze(0)  # pe has shape (1, S, H)
        # Stored as a frozen Parameter (rather than a buffer) as a hack for
        # distributed data parallel.
        self.pe = torch.nn.Parameter(pe, requires_grad=False)

    def forward(self, x):
        """Add the pe to the input tensor."""
        # x of shape (B, S, H)
        return x + self.pe[:, : x.size(1)]


POS_ENCODING_DICT = {
    "learned": lambda dim, size, **_: LearnedPosEncoding(
        hidden_dim=dim, context_window=size
    ),
    "rope": lambda **_: IdentityEncoding(),
    "none": lambda **_: IdentityEncoding(),
    "sincos": lambda dim, size, **_: SinCosPosEncoding(
        hidden_dim=dim, context_window=size
    ),
}


def build_positional_encodings(model_cfg):
    """
    Given the positional encoding config, build it.
    Args:
        model_cfg: the model config dictionary
    Returns:
        positional_encodings: an initialised positional encoding module
    """
    return POS_ENCODING_DICT[model_cfg["positional_encoding_type"]](
        dim=model_cfg["hidden_dim"], size=model_cfg["context_window"]
    )
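

# A minimal usage sketch. The config key names mirror the ones read by
# build_positional_encodings above; the dimensions and tensor sizes below are
# illustrative assumptions, not values from any real training config.
if __name__ == "__main__":
    example_cfg = {
        "positional_encoding_type": "sincos",
        "hidden_dim": 64,
        "context_window": 128,
    }
    pos_enc = build_positional_encodings(example_cfg)
    dummy = torch.randn(2, 16, 64)  # (B, S, H) with S <= context_window
    encoded = pos_enc(dummy)
    print(encoded.shape)  # torch.Size([2, 16, 64])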