0xZohar committed on
Commit
9fb53ab
·
verified ·
1 Parent(s): 83aa6ec

Add code/cube3d/model/autoencoder/embedder.py

Browse files
code/cube3d/model/autoencoder/embedder.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
class PhaseModulatedFourierEmbedder(torch.nn.Module):
    """Fourier-feature embedder combining a learned and a fixed sinusoid bank.

    Every input coordinate is expanded into ``num_freqs`` cosine and
    ``num_freqs`` sine features. Each feature is the sum of two sinusoids:
    one whose argument is the coordinate scaled by a learnable ``weight``, and
    one whose argument is the coordinate scaled by ``pi / 2`` and shifted by a
    fixed ``carrier`` offset. The raw input is prepended to the features.
    """

    def __init__(
        self,
        num_freqs: int,
        input_dim: int = 3,
    ) -> None:
        """
        Initializes the PhaseModulatedFourierEmbedder class.
        Args:
            num_freqs (int): The number of frequencies to be used.
            input_dim (int, optional): The dimension of the input. Defaults to 3.
        Attributes:
            weight (torch.nn.Parameter): Learnable tensor of shape
                (input_dim, num_freqs), initialized from a standard normal
                scaled by sqrt(0.5 * num_freqs).
            carrier (torch.Tensor): Fixed tensor of shape (num_freqs,),
                registered as a non-persistent buffer.
            out_dim (int): Size of the last dimension produced by ``forward``,
                i.e. input_dim * (2 * num_freqs + 1).
        """

        super().__init__()

        self.weight = nn.Parameter(
            torch.randn(input_dim, num_freqs) * math.sqrt(0.5 * num_freqs)
        )

        # NOTE this is the highest frequency we can get (2 for peaks, 2 for zeros, and 4 for interpolation points), see also https://en.wikipedia.org/wiki/Nyquist%E2%80%93Shannon_sampling_theorem
        carrier = (num_freqs / 8) ** torch.linspace(1, 0, num_freqs)
        carrier = (carrier + torch.linspace(0, 1, num_freqs)) * 2 * torch.pi
        # persistent=False: the buffer is rebuilt from num_freqs on
        # construction, so it is kept out of the state_dict.
        self.register_buffer("carrier", carrier, persistent=False)

        # One raw copy of the input plus a cos and a sin feature per
        # (coordinate, frequency) pair.
        self.out_dim = input_dim * (num_freqs * 2 + 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Perform the forward pass of the embedder model.
        Args:
            x (torch.Tensor): Input tensor of shape (..., input_dim).
        Returns:
            torch.Tensor: Output tensor of shape (..., out_dim) where
                out_dim = input_dim * (2 * num_freqs + 1). The last dimension
                is laid out as [x, cosine features, sine features]; each
                feature group is flattened from (input_dim, num_freqs).
        Raises:
            ValueError: If the last dimension of ``x`` is not ``input_dim``.
        """

        input_dim = self.weight.shape[0]
        # Without this check a trailing dimension of 1 would broadcast
        # silently against the weight and yield a width different from out_dim.
        if x.dim() == 0 or x.shape[-1] != input_dim:
            raise ValueError(
                f"expected input with last dimension {input_dim}, "
                f"got shape {tuple(x.shape)}"
            )

        # (..., input_dim, 1) so it broadcasts against (input_dim, num_freqs)
        # and (num_freqs,).
        m = x.float().unsqueeze(-1)
        # Learned branch: per-coordinate, per-frequency scaling.
        fm = (m * self.weight).view(*x.shape[:-1], -1)
        # Fixed branch: coordinate scaled by pi / 2, offset by the carrier.
        pm = (m * 0.5 * torch.pi + self.carrier).view(*x.shape[:-1], -1)
        embedding = torch.cat([x, fm.cos() + pm.cos(), fm.sin() + pm.sin()], dim=-1)

        return embedding