from typing import Optional

import torch


class HarmonicEmbedding(torch.nn.Module):
    def __init__(
        self,
        n_harmonic_functions: int = 6,
        omega_0: float = 1.0,
        logspace: bool = True,
        append_input: bool = True,
    ) -> None:
| """ |
| The harmonic embedding layer supports the classical |
| Nerf positional encoding described in |
| `NeRF <https://arxiv.org/abs/2003.08934>`_ |
| and the integrated position encoding in |
| `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_. |
| |
| During the inference you can provide the extra argument `diag_cov`. |
| |
| If `diag_cov is None`, it converts |
| rays parametrized with a `ray_bundle` to 3D points by |
| extending each ray according to the corresponding length. |
| Then it converts each feature |
| (i.e. vector along the last dimension) in `x` |
| into a series of harmonic features `embedding`, |
| where for each i in range(dim) the following are present |
| in embedding[...]:: |

            [
                sin(f_1 * x[..., i]),
                sin(f_2 * x[..., i]),
                ...
                sin(f_N * x[..., i]),
                cos(f_1 * x[..., i]),
                cos(f_2 * x[..., i]),
                ...
                cos(f_N * x[..., i]),
                x[..., i],              # only present if append_input is True.
            ]

        where N equals `n_harmonic_functions`, and f_i is a scalar
        denoting the i-th frequency of the harmonic embedding.

        If `diag_cov is not None`, the layer applies the integrated
        position encoding: it approximates the conical frustums along
        a ray bundle as Gaussians, with means `x` and diagonal
        covariances `diag_cov`.
        Then it converts each Gaussian
        into a series of harmonic features `embedding`,
        where for each i in range(dim) the following are present
        in embedding[...]::

            [
                sin(f_1 * x[..., i]) * exp(-0.5 * f_1**2 * diag_cov[..., i]),
                sin(f_2 * x[..., i]) * exp(-0.5 * f_2**2 * diag_cov[..., i]),
                ...
                sin(f_N * x[..., i]) * exp(-0.5 * f_N**2 * diag_cov[..., i]),
                cos(f_1 * x[..., i]) * exp(-0.5 * f_1**2 * diag_cov[..., i]),
                cos(f_2 * x[..., i]) * exp(-0.5 * f_2**2 * diag_cov[..., i]),
                ...
                cos(f_N * x[..., i]) * exp(-0.5 * f_N**2 * diag_cov[..., i]),
                x[..., i],              # only present if append_input is True.
            ]

        where N equals `n_harmonic_functions`, and f_i is a scalar
        denoting the i-th frequency of the harmonic embedding.

        If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
        powers of 2:
        `f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`

        If `logspace==False`, the frequencies are linearly spaced between
        `1.0` and `2**(n_harmonic_functions-1)`:
        `f_1, ..., f_N = torch.linspace(
            1.0, 2**(n_harmonic_functions-1), n_harmonic_functions
        )`

        Note that `x` is also premultiplied by the base frequency `omega_0`
        before evaluating the harmonic functions.
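
        For instance, `n_harmonic_functions=4` with `logspace=True` and
        `omega_0=1.0` yields the frequencies `[1., 2., 4., 8.]`.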

        Args:
            n_harmonic_functions: int, number of harmonic
                features
            omega_0: float, base frequency
            logspace: bool, whether to space the frequencies in
                logspace or linear space
            append_input: bool, whether to concat the original
                input to the harmonic embedding. If True, the
                output is of the form (embed.sin(), embed.cos(), x)
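
        Example::

            # A minimal usage sketch with illustrative values:
            embedder = HarmonicEmbedding(n_harmonic_functions=6)
            points = torch.randn(10, 3)  # e.g. xyz coordinates
            out = embedder(points)
            # out.shape == (10, 39), i.e. 3 * (2 * 6 + 1)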
| """ |
| super().__init__() |
|
|
        if logspace:
            frequencies = 2.0 ** torch.arange(
                n_harmonic_functions,
                dtype=torch.float32,
            )
        else:
            frequencies = torch.linspace(
                1.0,
                2.0 ** (n_harmonic_functions - 1),
                n_harmonic_functions,
                dtype=torch.float32,
            )

        self.register_buffer("_frequencies", frequencies * omega_0, persistent=False)
        # Phase shifts [0, pi/2]: since sin(t + pi/2) == cos(t), a single
        # sin() call in forward() evaluates both the sine and the cosine
        # harmonics.
        self.register_buffer(
            "_zero_half_pi", torch.tensor([0.0, 0.5 * torch.pi]), persistent=False
        )
        self.append_input = append_input

    def forward(
        self, x: torch.Tensor, diag_cov: Optional[torch.Tensor] = None, **kwargs
    ) -> torch.Tensor:
| """ |
| Args: |
| x: tensor of shape [..., dim] |
| diag_cov: An optional tensor of shape `(..., dim)` |
| representing the diagonal covariance matrices of our Gaussians, joined with x |
| as means of the Gaussians. |
| |
| Returns: |
| embedding: a harmonic embedding of `x` of shape |
| [..., (n_harmonic_functions * 2 + int(append_input)) * num_points_per_ray] |
| """ |
        # Outer product of the inputs and the frequencies:
        # [..., dim, n_harmonic_functions]
        embed = x[..., None] * self._frequencies
        # Add the [0, pi/2] phase shifts so that the single sin() below
        # yields both sines and cosines: [..., 2, dim, n_harmonic_functions]
        embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None]
        embed = embed.sin()
        if diag_cov is not None:
            # Integrated position encoding (MIP-NeRF): attenuate each
            # harmonic by exp(-0.5 * f_i**2 * var) of its Gaussian.
            x_var = diag_cov[..., None] * torch.pow(self._frequencies, 2)
            exp_var = torch.exp(-0.5 * x_var)
            embed = embed * exp_var[..., None, :, :]
        # Flatten the (2, dim, n_harmonic_functions) harmonics into the
        # last dimension.
        embed = embed.reshape(*x.shape[:-1], -1)
        if self.append_input:
            return torch.cat([embed, x], dim=-1)
        return embed

    @staticmethod
    def get_output_dim_static(
        input_dims: int,
        n_harmonic_functions: int,
        append_input: bool,
    ) -> int:
| """ |
| Utility to help predict the shape of the output of `forward`. |
| |
| Args: |
| input_dims: length of the last dimension of the input tensor |
| n_harmonic_functions: number of embedding frequencies |
| append_input: whether or not to concat the original |
| input to the harmonic embedding |
| Returns: |
| int: the length of the last dimension of the output tensor |
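
        For example, `input_dims=3`, `n_harmonic_functions=6` and
        `append_input=True` give `3 * (2 * 6 + 1) = 39`.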
| """ |
| return input_dims * (2 * n_harmonic_functions + int(append_input)) |

    def get_output_dim(self, input_dims: int = 3) -> int:
        """
        Same as above. The default for `input_dims` is 3, since in 3D
        applications the harmonic embedding typically serves as a
        positional encoding of xyz coordinates.
        """
        return self.get_output_dim_static(
            input_dims, len(self._frequencies), self.append_input
        )
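

# A small, self-contained usage sketch with illustrative values; it is not
# part of the module above. It exercises both the classical encoding and,
# with a hypothetical constant per-point variance, the integrated variant.
if __name__ == "__main__":
    embedder = HarmonicEmbedding(n_harmonic_functions=6, append_input=True)
    xyz = torch.randn(4, 8, 3)  # e.g. 4 rays with 8 sample points each

    # Classical NeRF positional encoding.
    pe = embedder(xyz)
    assert pe.shape == (4, 8, embedder.get_output_dim())  # 3 * (2*6 + 1) == 39

    # Integrated position encoding: higher frequencies are attenuated
    # according to the (here uniform) diagonal covariance.
    diag_cov = torch.full_like(xyz, 0.1)
    ipe = embedder(xyz, diag_cov=diag_cov)
    assert ipe.shape == pe.shape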