| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
| from torch import nn |
| from typing import Optional |
|
|
|
|
# Short name for the torch tensor type, used in the annotations of this module.
from torch import Tensor
|
|
|
|
class NeRFEncoding(nn.Module):
    """Multi-scale sinusoidal (NeRF positional) encoding.

    Each input axis is multiplied by ``num_frequencies`` frequencies
    ``2**e``, where the exponents ``e`` are evenly spaced from
    ``min_freq_exp`` to ``max_freq_exp``. The sine and cosine of every
    scaled value are emitted.

    Args:
        in_dim: Size of the last (feature) dimension of the input tensor.
        num_frequencies: Number of encoded frequencies per axis. ``0`` is
            accepted and yields an empty sinusoidal part (useful together
            with ``include_input=True``).
        min_freq_exp: Exponent of the lowest frequency.
        max_freq_exp: Exponent of the highest frequency. When ``None`` it
            defaults to ``num_frequencies - 1``.
        include_input: Prepend the raw input coordinates to the encoding.
    """

    def __init__(
        self,
        in_dim: int,
        num_frequencies: int,
        min_freq_exp: float = 0.0,
        max_freq_exp: Optional[float] = None,
        include_input: bool = False,
    ) -> None:
        super().__init__()
        if max_freq_exp is None:
            # With the default min_freq_exp of 0 this gives integer
            # exponents 0, 1, ..., num_frequencies - 1.
            max_freq_exp = num_frequencies - 1

        self.in_dim = in_dim
        self.num_frequencies = num_frequencies
        self.min_freq = min_freq_exp
        self.max_freq = max_freq_exp
        self.include_input = include_input

    def get_out_dim(self) -> int:
        """Return the size of the last dimension produced by ``forward``.

        Raises:
            ValueError: If ``in_dim`` is ``None``.
        """
        if self.in_dim is None:
            raise ValueError("Input dimension has not been set")
        # One sine and one cosine per (axis, frequency) pair.
        out_dim = self.in_dim * self.num_frequencies * 2
        if self.include_input:
            out_dim += self.in_dim
        return out_dim

    def forward(self, in_tensor: Tensor) -> Tensor:
        """Compute the sinusoidal encoding of ``in_tensor``.

        Args:
            in_tensor: Input coordinates, shape ``[*bs, input_dim]``. For
                best performance the values should be between 0 and 1.

        Returns:
            Tensor of shape ``[*bs, output_dim]``, where ``output_dim`` equals
            ``get_out_dim()`` when ``input_dim == in_dim``. The layout is
            ``[input (optional) | sin block | cos block]``. Within each block,
            channels are grouped by input axis, then by frequency. The
            sinusoidal channels lie in ``[-1, 1]``. If ``include_input`` is
            set, the first ``input_dim`` channels are the raw, unbounded
            inputs.
        """
        freqs = 2 ** torch.linspace(
            self.min_freq, self.max_freq, self.num_frequencies, device=in_tensor.device
        )
        # [*bs, input_dim, num_frequencies] -> [*bs, input_dim * num_frequencies].
        # ``flatten`` is used instead of ``view(..., -1)`` because the latter
        # fails on the zero-element tensor produced when num_frequencies == 0.
        scaled_inputs = (in_tensor[..., None] * freqs).flatten(start_dim=-2)
        # sin(x + pi/2) == cos(x), so a single sin call yields both halves.
        encoded_inputs = torch.sin(
            torch.cat([scaled_inputs, scaled_inputs + torch.pi / 2.0], dim=-1)
        )

        if self.include_input:
            encoded_inputs = torch.cat([in_tensor, encoded_inputs], dim=-1)

        return encoded_inputs
|
|