Spaces:
Paused
Paused
Add code/cube3d/model/autoencoder/embedder.py
Browse files
code/cube3d/model/autoencoder/embedder.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class PhaseModulatedFourierEmbedder(torch.nn.Module):
    """Fourier-feature embedder combining learned frequencies with a fixed,
    phase-modulated carrier, concatenated with the raw input coordinates."""

    def __init__(
        self,
        num_freqs: int,
        input_dim: int = 3,
    ):
        """
        Initializes the PhaseModulatedFourierEmbedder class.

        Args:
            num_freqs (int): The number of frequencies to be used.
            input_dim (int, optional): The dimension of the input. Defaults to 3.

        Attributes:
            weight (torch.nn.Parameter): Learned frequency matrix of shape
                (input_dim, num_freqs), initialized from a standard normal
                scaled by sqrt(0.5 * num_freqs).
            carrier (torch.Tensor): Fixed carrier frequencies registered as a
                non-persistent buffer (not saved in state_dict), spaced based on
                the Nyquist-Shannon sampling theorem.
            out_dim (int): Output feature dimension,
                input_dim * (2 * num_freqs + 1).
        """
        super().__init__()

        self.weight = nn.Parameter(
            torch.randn(input_dim, num_freqs) * math.sqrt(0.5 * num_freqs)
        )

        # NOTE this is the highest frequency we can get (2 for peaks, 2 for zeros,
        # and 4 for interpolation points), see also
        # https://en.wikipedia.org/wiki/Nyquist%E2%80%93Shannon_sampling_theorem
        carrier = (num_freqs / 8) ** torch.linspace(1, 0, num_freqs)
        carrier = (carrier + torch.linspace(0, 1, num_freqs)) * 2 * torch.pi
        self.register_buffer("carrier", carrier, persistent=False)

        self.out_dim = input_dim * (num_freqs * 2 + 1)

    def forward(self, x):
        """
        Perform the forward pass of the embedder model.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, ..., input_dim).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, ..., out_dim) where
                out_dim = input_dim * (2 * num_freqs + 1): the raw input
                concatenated with the summed cos features and summed sin
                features of the learned and carrier frequencies.
        """
        # Broadcast each input coordinate against all frequencies:
        # m has shape (..., input_dim, 1).
        m = x.float().unsqueeze(-1)
        # Learned-frequency features, flattened to (..., input_dim * num_freqs).
        fm = (m * self.weight).view(*x.shape[:-1], -1)
        # Phase-modulated carrier features, same flattened shape.
        pm = (m * 0.5 * torch.pi + self.carrier).view(*x.shape[:-1], -1)
        # Sum the two feature families per (dim, freq) slot, keep raw x up front.
        embedding = torch.cat([x, fm.cos() + pm.cos(), fm.sin() + pm.sin()], dim=-1)

        return embedding
|