| | import torch |
| | import torch.nn as nn |
| | from einops import einsum, rearrange, repeat |
| | from jaxtyping import Float |
| | from torch import Tensor |
| |
|
| |
|
class PositionalEncoding(nn.Module):
    """For the sake of simplicity, this encodes values in the range [0, 1]."""

    # Per-(octave, phase) angular frequencies: 2*pi * 2^k for k = 0..num_octaves-1.
    frequencies: Float[Tensor, "frequency phase"]
    # Phase offsets 0 and pi/2, so each octave produces a (sin, cos) pair.
    phases: Float[Tensor, "frequency phase"]

    def __init__(self, num_octaves: int):
        super().__init__()

        # Frequencies double at each octave, starting from one full sine
        # period over [0, 1]; duplicate along the phase axis so both the
        # sin and cos slots of an octave share the same frequency.
        octave_index = torch.arange(num_octaves).float()
        freq = repeat(2 * torch.pi * 2**octave_index, "f -> f p", p=2)
        self.register_buffer("frequencies", freq, persistent=False)

        # Shifting sin by pi/2 yields cos, giving the usual sin/cos features.
        offsets = torch.tensor([0, 0.5 * torch.pi], dtype=torch.float32)
        phase_grid = repeat(offsets, "p -> f p", f=num_octaves)
        self.register_buffer("phases", phase_grid, persistent=False)

    def forward(
        self,
        samples: Float[Tensor, "*batch dim"],
    ) -> Float[Tensor, "*batch embedded_dim"]:
        """Embed each input channel as num_octaves * 2 sinusoidal features."""
        # Outer product of every sample channel with every (frequency, phase)
        # slot, then apply the per-slot phase shift.
        angles = einsum(samples, self.frequencies, "... d, f p -> ... d f p") + self.phases
        # Collapse the (dim, frequency, phase) axes into one feature axis.
        return rearrange(torch.sin(angles), "... d f p -> ... (d f p)")

    def d_out(self, dimensionality: int) -> int:
        """Number of output features for an input with `dimensionality` channels."""
        # frequencies.numel() == num_octaves * 2 (sin and cos per octave).
        return self.frequencies.numel() * dimensionality
| |
|