# Modified from https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/field_components/encodings.py
# Copyright 2022 The Nerfstudio Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn
from typing import Optional

Tensor = torch.Tensor


class NeRFEncoding(nn.Module):
    """Multi-scale sinusoidal encodings.

    Each axis is encoded with ``num_frequencies`` frequencies whose exponents are
    evenly spaced from ``min_freq_exp`` to ``max_freq_exp`` (inclusive), i.e. the
    frequencies range from 2^min_freq_exp to 2^max_freq_exp.

    Args:
        in_dim: Input dimension of tensor
        num_frequencies: Number of encoded frequencies per axis
        min_freq_exp: Minimum frequency exponent
        max_freq_exp: Maximum frequency exponent. Defaults to ``num_frequencies - 1``.
        include_input: Concatenate the raw input coordinates in front of the encoding
    """

    def __init__(
        self,
        in_dim: int,
        num_frequencies: int,
        min_freq_exp: float = 0.,
        max_freq_exp: Optional[float] = None,
        include_input: bool = False
    ) -> None:
        super().__init__()
        if max_freq_exp is None:
            # With the default min_freq_exp=0 this gives exponents 0, 1, ..., num_frequencies - 1.
            max_freq_exp = num_frequencies - 1
        self.in_dim = in_dim
        self.num_frequencies = num_frequencies
        self.min_freq = min_freq_exp
        self.max_freq = max_freq_exp
        self.include_input = include_input

    def get_out_dim(self) -> int:
        """Return the size of the last dimension of the tensor produced by ``forward``.

        Raises:
            ValueError: If ``in_dim`` is None.
        """
        if self.in_dim is None:
            raise ValueError("Input dimension has not been set")
        # One sin and one cos term per (axis, frequency) pair.
        out_dim = self.in_dim * self.num_frequencies * 2
        if self.include_input:
            out_dim += self.in_dim
        return out_dim

    def forward(
        self, in_tensor: Tensor
    ) -> Tensor:
        """Calculates NeRF encoding.

        Args:
            in_tensor: For best performance, the input tensor should be between 0 and 1. [*bs, input_dim]

        Returns:
            Encoded tensor of shape [*bs, output_dim], laid out as
            ``[in_tensor (only if include_input), sin terms, cos terms]``. Within
            each of the sin/cos halves the frequency index varies fastest, i.e.
            entry ``d * num_frequencies + f`` uses axis ``d`` and frequency ``f``.
            The sinusoidal part lies in [-1, 1]; the optional raw input is passed
            through unchanged.
        """
        freqs = 2 ** torch.linspace(self.min_freq, self.max_freq, self.num_frequencies, device=in_tensor.device)
        scaled_inputs = in_tensor[..., None] * freqs  # [..., "input_dim", "num_scales"]
        # Merge the last two dims. `reshape` (unlike `view`) also works when the product
        # is non-contiguous (e.g. for a transposed input), and an explicit size instead
        # of -1 keeps empty batches (0 elements) from raising.
        flat_size = scaled_inputs.shape[-2:].numel()
        scaled_inputs = scaled_inputs.reshape(*scaled_inputs.shape[:-2], flat_size)  # [..., "input_dim" * "num_scales"]
        # sin(x + pi/2) == cos(x), so a single sin call produces both halves.
        encoded_inputs = torch.sin(torch.cat([scaled_inputs, scaled_inputs + torch.pi / 2.0], dim=-1))
        if self.include_input:
            encoded_inputs = torch.cat([in_tensor, encoded_inputs], dim=-1)
        return encoded_inputs