import math

import torch


def encode_single(d_model, value, max_period=10000.0):
    """
    Encode a single scalar value as a sinusoidal vector.

    :param d_model: dimension of the model
    :param value: the value to encode
    :param max_period: controls the minimum frequency of the encoding
    :return: a d_model-dimensional encoding vector
    """
    if d_model % 2 != 0:
        raise ValueError(
            "Cannot use sin/cos positional encoding with "
            "odd dim (got dim={:d})".format(d_model),
        )
    pe = torch.zeros(d_model)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float)
        * -(math.log(max_period) / d_model),
    )
    pe[0::2] = torch.sin(value * div_term)
    pe[1::2] = torch.cos(value * div_term)
    return pe


def timestep_embedding(t, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, dim) Tensor of positional embeddings.
    """
    # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
        / half,
    )
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


def offset_sequence_embedding(t, dim, max_period=10000):
    """
    Create sinusoidal embeddings for a batch of time-offset sequences.

    :param t: an (N, T) Tensor of sequences of time offsets.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, T, dim) Tensor of positional embeddings.
    """
    N, T = t.shape
    flattened = torch.flatten(t)
    embedding = timestep_embedding(flattened, dim, max_period)
    return torch.reshape(embedding, (N, T, dim))


def position_sequence_embedding(t, dim, max_period=10000):
    """
    Create sinusoidal embeddings for a batch of position sequences, encoding
    each coordinate independently.

    :param t: an (N, T, D) Tensor of sequences of D-dimensional positions.
    :param dim: the embedding dimension per coordinate.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, T, D * dim) Tensor of positional embeddings.
    """
    N, T, D = t.shape
    flattened = torch.flatten(t)
    embedding = timestep_embedding(flattened, dim, max_period)
    return torch.reshape(embedding, (N, T, D * dim))
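# A minimal usage sketch (not part of the original module): the shapes below
# follow the docstrings above, and _demo_sequence_embeddings is a hypothetical
# helper added purely for illustration.
def _demo_sequence_embeddings():
    t = torch.rand(4) * 1000.0  # (N,) fractional timesteps
    emb = timestep_embedding(t, 128)  # -> (4, 128)

    offsets = torch.rand(4, 16)  # (N, T) per-element time offsets
    off_emb = offset_sequence_embedding(offsets, 128)  # -> (4, 16, 128)

    pos = torch.rand(4, 16, 2)  # (N, T, D) two-dimensional positions
    pos_emb = position_sequence_embedding(pos, 64)  # -> (4, 16, 2 * 64)

    return emb.shape, off_emb.shape, pos_emb.shape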
""" N, T, D = t.shape flattened = torch.flatten(t) embedding = timestep_embedding(flattened, dim, max_period) return torch.reshape(embedding, (N, T, D * dim)) def positionalencoding(d_model, values, max_period=10000.0): """ :param d_model: dimension of the model :param values: the values to encode :param max_period: the maximum allowed value :return: length*d_model position matrix """ if d_model % 2 != 0: raise ValueError( "Cannot use sin/cos positional encoding with " "odd dim (got dim={:d})".format(d_model), ) pe = torch.zeros(len(values), d_model) position = values.unsqueeze(1) div_term = torch.exp( torch.arange(0, d_model, 2, dtype=torch.float) * -(math.log(max_period) / d_model), ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) return pe def positionalencoding1d(d_model, length): """ :param d_model: dimension of the model :param length: length of positions :return: length*d_model position matrix """ if d_model % 2 != 0: raise ValueError( "Cannot use sin/cos positional encoding with " "odd dim (got dim={:d})".format(d_model), ) pe = torch.zeros(2, d_model) position = torch.arange(-50, 50, 100).unsqueeze(1) div_term = torch.exp( torch.arange(0, d_model, 2, dtype=torch.float) * -(math.log(10000.0) / d_model), ) pe[:, 0::2] = torch.sin(position.float() * div_term) pe[:, 1::2] = torch.cos(position.float() * div_term) return pe def positionalencoding2d(d_model, height, width): """ :param d_model: dimension of the model :param height: height of the positions :param width: width of the positions :return: d_model*height*width position matrix """ if d_model % 4 != 0: raise ValueError( "Cannot use sin/cos positional encoding with " "odd dimension (got dim={:d})".format(d_model), ) pe = torch.zeros(d_model, height, width) # Each dimension use half of d_model d_model = int(d_model / 2) div_term = torch.exp(torch.arange(0.0, d_model, 2) * -(math.log(10000.0) / d_model)) pos_w = torch.arange(0.0, width).unsqueeze(1) pos_h = torch.arange(0.0, height).unsqueeze(1) pe[0:d_model:2, :, :] = ( torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) ) pe[1:d_model:2, :, :] = ( torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) ) pe[d_model::2, :, :] = ( torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) ) pe[d_model + 1 :: 2, :, :] = ( torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) ) return pe if __name__ == "__main__": import matplotlib.pyplot as plt pe = positionalencoding(128, torch.tensor([-50, 50])) plt.imshow(pe) plt.show()