import math
import torch


def encode_single(d_model, value, max_period=10000.0):
    """
    :param d_model: dimension of the model
    :param value: the scalar value to encode
    :param max_period: controls the minimum frequency of the encoding
    :return: a d_model-dimensional encoding vector
    """
    if d_model % 2 != 0:
        raise ValueError(
            "Cannot use sin/cos positional encoding with "
            "odd dim (got dim={:d})".format(d_model),
        )
    pe = torch.zeros(d_model)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float)
        * -(math.log(max_period) / d_model),
    )
    pe[0::2] = torch.sin(value * div_term)
    pe[1::2] = torch.cos(value * div_term)
    return pe
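
# Usage sketch (illustrative, not part of the original file): encode one scalar
# position into a 128-dimensional sin/cos vector.
#   vec = encode_single(128, 3.5)
#   print(vec.shape)  # torch.Size([128])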


def timestep_embedding(t, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
              These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, dim) Tensor of positional embeddings.
    """
    # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
        / half,
    )
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
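
# Usage sketch (illustrative, not part of the original file): embed a batch of
# four diffusion timesteps into 128-dimensional vectors.
#   t = torch.tensor([0.0, 10.5, 250.0, 999.0])
#   emb = timestep_embedding(t, dim=128)
#   print(emb.shape)  # torch.Size([4, 128])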


def offset_sequence_embedding(t, dim, max_period=10000):
    """
    Create sinusoidal embeddings for a batch of time-offset sequences.
    :param t: an (N, T) Tensor of sequences of time offsets.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, T, dim) Tensor of positional embeddings.
    """
    N, T = t.shape
    flattened = torch.flatten(t)
    embedding = timestep_embedding(flattened, dim, max_period)
    return torch.reshape(embedding, (N, T, dim))
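
# Usage sketch (illustrative, not part of the original file): each of the N x T
# offsets gets its own dim-dimensional embedding.
#   offsets = torch.rand(2, 16) * 100  # hypothetical batch of 2 sequences of 16 offsets
#   emb = offset_sequence_embedding(offsets, dim=64)
#   print(emb.shape)  # torch.Size([2, 16, 64])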


def position_sequence_embedding(t, dim, max_period=10000):
    """
    Create sinusoidal embeddings for sequences of D-dimensional positions,
    encoding each coordinate independently and concatenating the results.
    :param t: an (N, T, D) Tensor of sequences of D-dimensional positions.
    :param dim: the dimension of the output per coordinate.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, T, D * dim) Tensor of positional embeddings.
    """
    N, T, D = t.shape
    flattened = torch.flatten(t)
    embedding = timestep_embedding(flattened, dim, max_period)
    return torch.reshape(embedding, (N, T, D * dim))
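
# Usage sketch (illustrative, not part of the original file): 2-D positions,
# so the last dimension of the result is 2 * dim.
#   pos = torch.rand(2, 16, 2) * 512  # hypothetical (x, y) positions
#   emb = position_sequence_embedding(pos, dim=64)
#   print(emb.shape)  # torch.Size([2, 16, 128])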


def positionalencoding(d_model, values, max_period=10000.0):
    """
    :param d_model: dimension of the model
    :param values: a 1-D Tensor of values to encode
    :param max_period: controls the minimum frequency of the encoding
    :return: a len(values) x d_model position matrix
    """
    if d_model % 2 != 0:
        raise ValueError(
            "Cannot use sin/cos positional encoding with "
            "odd dim (got dim={:d})".format(d_model),
        )
    pe = torch.zeros(len(values), d_model)
    position = values.unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float)
        * -(math.log(max_period) / d_model),
    )
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe
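
# Usage sketch (illustrative, not part of the original file): encode several
# scalar positions at once, one row per value.
#   values = torch.tensor([-50.0, 0.0, 50.0])
#   pe = positionalencoding(128, values)
#   print(pe.shape)  # torch.Size([3, 128])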


def positionalencoding1d(d_model, length):
    """
    :param d_model: dimension of the model
    :param length: length of positions
    :return: length*d_model position matrix
    """
    if d_model % 2 != 0:
        raise ValueError(
            "Cannot use sin/cos positional encoding with "
            "odd dim (got dim={:d})".format(d_model),
        )
    pe = torch.zeros(length, d_model)
    position = torch.arange(0, length).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float) * -(math.log(10000.0) / d_model),
    )
    pe[:, 0::2] = torch.sin(position.float() * div_term)
    pe[:, 1::2] = torch.cos(position.float() * div_term)
    return pe
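
# Usage sketch (illustrative, not part of the original file): a standard
# sequence-position table with one row per position 0..length-1.
#   pe = positionalencoding1d(128, length=16)
#   print(pe.shape)  # torch.Size([16, 128])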


def positionalencoding2d(d_model, height, width):
    """
    :param d_model: dimension of the model (must be divisible by 4)
    :param height: height of the positions
    :param width: width of the positions
    :return: d_model*height*width position matrix
    """
    if d_model % 4 != 0:
        raise ValueError(
            "Cannot use sin/cos positional encoding with "
            "a dimension not divisible by 4 (got dim={:d})".format(d_model),
        )
    pe = torch.zeros(d_model, height, width)
    # Each spatial dimension uses half of d_model.
    d_model = int(d_model / 2)
    div_term = torch.exp(torch.arange(0.0, d_model, 2) * -(math.log(10000.0) / d_model))
    pos_w = torch.arange(0.0, width).unsqueeze(1)
    pos_h = torch.arange(0.0, height).unsqueeze(1)
    pe[0:d_model:2, :, :] = (
        torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    )
    pe[1:d_model:2, :, :] = (
        torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    )
    pe[d_model::2, :, :] = (
        torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    )
    pe[d_model + 1 :: 2, :, :] = (
        torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    )
    return pe
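
# Usage sketch (illustrative, not part of the original file): the first half of
# the channels encodes the x (width) position, the second half the y (height).
#   pe2d = positionalencoding2d(64, height=32, width=48)
#   print(pe2d.shape)  # torch.Size([64, 32, 48])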


if __name__ == "__main__":
    import matplotlib.pyplot as plt

    pe = positionalencoding(128, torch.tensor([-50, 50]))
    plt.imshow(pe)
    plt.show()