import math

import numpy as np
import torch
from torch import nn

def get_emb(sin_inp: torch.Tensor) -> torch.Tensor:
    """
    Builds the base embedding for one axis, with sin and cos interleaved
    along the last dimension.
    """
    # Stack the sin/cos pairs, then flatten the last two dimensions so the
    # output alternates [sin, cos, sin, cos, ...] over the channel axis.
    emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
    return torch.flatten(emb, -2, -1)

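
# Illustrative example (not part of the original module): for phases
# [0, pi/2], the interleaved output is approximately
# [sin(0), cos(0), sin(pi/2), cos(pi/2)] = [0, 1, 1, 0]:
#
#     >>> get_emb(torch.tensor([0.0, math.pi / 2]))
#     tensor([0., 1., 1., 0.])   # up to float rounding
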
class PositionalEncoding(nn.Module):
    def __init__(self,
                 dim: int,
                 scale: float = math.pi * 2,
                 temperature: float = 10000,
                 normalize: bool = True,
                 channel_last: bool = True,
                 transpose_output: bool = False):
        super().__init__()
        # Split the requested dimension across the two spatial axes, rounding
        # up so each axis gets an even channel count (half sin, half cos);
        # the encoding has 2 * dim channels in total.
        dim = int(np.ceil(dim / 4) * 2)
        self.dim = dim
        # Standard sinusoidal frequencies: inv_freq[i] = temperature^(-2i / dim).
        inv_freq = 1.0 / (temperature**(torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.normalize = normalize
        self.scale = scale
        self.eps = 1e-6
        self.channel_last = channel_last
        self.transpose_output = transpose_output

        # Cache of the last computed encoding, reused while the input shape
        # stays the same; independent of the number of objects.
        self.cached_penc = None

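    # Worked example (illustrative, not from the original source): with the
    # default temperature=10000 and an effective dim of 4, arange(0, 4, 2) / 4
    # gives exponents [0, 0.5], so inv_freq = [10000^0, 10000^-0.5] = [1.0, 0.01].
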
    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        :param tensor: A 4/5d tensor of size
            channel_last=True: (batch_size, h, w, c) or (batch_size, k, h, w, c)
            channel_last=False: (batch_size, c, h, w) or (batch_size, k, c, h, w)
        :return: a positional encoding tensor with the same shape as the input if
                 the input is 4d; if the input is 5d, the output is broadcastable
                 along the k-dimension
        """
        if len(tensor.shape) != 4 and len(tensor.shape) != 5:
            raise RuntimeError(f'The input tensor has to be 4/5d, got {tensor.shape}!')

        if len(tensor.shape) == 5:
            # The encoding is shared across the k (object) dimension, so one
            # sample from it is enough.
            num_objects = tensor.shape[1]
            tensor = tensor[:, 0]
        else:
            num_objects = None

        if self.channel_last:
            batch_size, h, w, c = tensor.shape
        else:
            batch_size, c, h, w = tensor.shape

        # Reuse the cached encoding when its shape matches the input exactly
        # (in the typical configurations this requires c == 2 * self.dim).
        if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
            if num_objects is None:
                return self.cached_penc
            else:
                return self.cached_penc.unsqueeze(1)

        self.cached_penc = None

        pos_y = torch.arange(h, device=tensor.device, dtype=self.inv_freq.dtype)
        pos_x = torch.arange(w, device=tensor.device, dtype=self.inv_freq.dtype)
        if self.normalize:
            # Rescale positions to [0, scale] (2*pi by default), making the
            # encoding independent of the absolute feature-map size.
            pos_y = pos_y / (pos_y[-1] + self.eps) * self.scale
            pos_x = pos_x / (pos_x[-1] + self.eps) * self.scale

        # Outer products of positions and frequencies: (h, dim/2) and (w, dim/2).
        sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        emb_y = get_emb(sin_inp_y).unsqueeze(1)  # (h, 1, dim), broadcast over w
        emb_x = get_emb(sin_inp_x)  # (w, dim), broadcast over h

        # The first half of the channels encodes x, the second half encodes y.
        emb = torch.zeros((h, w, self.dim * 2), device=tensor.device, dtype=tensor.dtype)
        emb[:, :, :self.dim] = emb_x
        emb[:, :, self.dim:] = emb_y

        if not self.channel_last and self.transpose_output:
            # The two flags cancel out, leaving emb in the (h, w, c) layout
            # it was built in.
            pass
        elif (not self.channel_last) or self.transpose_output:
            # Move channels to the front: (h, w, c) -> (c, h, w).
            emb = emb.permute(2, 0, 1)

        self.cached_penc = emb.unsqueeze(0).repeat(batch_size, 1, 1, 1)
        if num_objects is None:
            return self.cached_penc
        else:
            return self.cached_penc.unsqueeze(1)


if __name__ == '__main__':
    # Quick sanity check; falls back to CPU when no GPU is available.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    pe = PositionalEncoding(8).to(device)
    x = torch.ones((1, 8, 8, 8), device=device)  # (batch, h, w, c), channel-last
    output = pe(x)

    # Slices along h (at two different channels), along w, and along the
    # channel dimension.
    print(output[0, :, 0, 0])
    print(output[0, :, 0, 5])
    print(output[0, 0, :, 0])
    print(output[0, 0, 0, :])
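
    # Additional sketch (an assumption, not part of the original demo):
    # exercising the channel-first and 5d code paths. With dim=8 the encoding
    # has 2 * ceil(8 / 4) * 2 = 8 channels, matching c below.
    pe_cf = PositionalEncoding(8, channel_last=False).to(device)
    x5 = torch.ones((2, 3, 8, 16, 16), device=device)  # (batch, k, c, h, w)
    out5 = pe_cf(x5)
    print(out5.shape)  # torch.Size([2, 1, 8, 16, 16]); broadcastable along k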