Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from einops import einsum, rearrange, repeat | |
| from jaxtyping import Float | |
| from torch import Tensor | |
class PositionalEncoding(nn.Module):
    """Sinusoidal encoding of coordinates using octave-spaced frequencies.

    Inputs are expected to lie in the range [0, 1]; restricting the domain
    this way keeps the implementation simple.

    For every input coordinate and every octave ``k`` in
    ``range(num_octaves)``, the encoding emits ``sin(w * x)`` and
    ``sin(w * x + pi / 2)`` (i.e. the cosine), where ``w = 2 * pi * 2**k``.
    """

    # Both buffers share the shape (num_octaves, 2); the trailing axis indexes
    # the two phase offsets (sine, cosine).
    frequencies: Float[Tensor, "frequency phase"]
    phases: Float[Tensor, "frequency phase"]

    def __init__(self, num_octaves: int):
        """Precompute the frequency and phase tables.

        Args:
            num_octaves: Number of frequency doublings to encode with.
        """
        super().__init__()

        # Octave 0 has angular frequency 2*pi, so it completes exactly one
        # period across a unit-length input interval; each subsequent octave
        # doubles the frequency.
        octave_index = torch.arange(num_octaves).float()
        angular_frequency = 2 * torch.pi * 2**octave_index

        # A quarter-period shift turns sine into cosine, so one torch.sin call
        # in forward() yields both components.
        quarter_turn = 0.5 * torch.pi
        phase_pair = torch.tensor([0, quarter_turn], dtype=torch.float32)

        # Duplicate each frequency once per phase, and each phase pair once
        # per octave, so the two tables line up element-wise.
        frequency_table = repeat(angular_frequency, "f -> f p", p=2)
        phase_table = repeat(phase_pair, "p -> f p", f=num_octaves)

        # Registered with persistent=False: the tables are rebuilt from
        # num_octaves on construction.
        self.register_buffer("frequencies", frequency_table, persistent=False)
        self.register_buffer("phases", phase_table, persistent=False)

    def forward(
        self,
        samples: Float[Tensor, "*batch dim"],
    ) -> Float[Tensor, "*batch embedded_dim"]:
        """Encode ``samples`` along their last axis.

        Args:
            samples: Tensor whose last axis holds the coordinates to encode.

        Returns:
            Tensor with the same leading axes as ``samples``. The last axis is
            the flattened (dim, frequency, phase) grid, with phase varying
            fastest and dim slowest.
        """
        # Outer product: every coordinate is scaled by every table entry.
        scaled = einsum(samples, self.frequencies, "... d, f p -> ... d f p")

        # The phase table broadcasts over the batch and coordinate axes.
        encoded = torch.sin(scaled + self.phases)

        return rearrange(encoded, "... d f p -> ... (d f p)")

    def d_out(self, dimensionality: int):
        """Return the encoded width for inputs with ``dimensionality`` coordinates.

        This is the number of entries in the frequency table multiplied by
        ``dimensionality``, matching the last axis produced by ``forward``.
        """
        features_per_coordinate = self.frequencies.numel()
        return features_per_coordinate * dimensionality