NeAR / trellis /models /structured_latent_vae /nerf_encoding.py
luh1124's picture
restore: full Space tree + assets (recover from minimal force-push); keep ZeroGPU app.py
0d1388f
# Modified from https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/field_components/encodings.py
# Copyright 2022 The Nerfstudio Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from typing import Optional
# Short alias so annotations in this module can be written as `Tensor`.
Tensor = torch.Tensor
class NeRFEncoding(nn.Module):
    """Multi-scale sinusoidal encodings.

    Each axis is encoded with frequencies ranging from 2^min_freq_exp to
    2^max_freq_exp. The exponents are spaced evenly (``torch.linspace``) and
    both the sine and the cosine of every scaled coordinate are emitted.

    Args:
        in_dim: Input dimension of tensor
        num_frequencies: Number of encoded frequencies per axis
        min_freq_exp: Minimum frequency exponent
        max_freq_exp: Maximum frequency exponent. When ``None`` it defaults to
            ``num_frequencies - 1``.
        include_input: Concatenate the raw input coordinates in front of the
            encoding
    """

    def __init__(
        self,
        in_dim: int,
        num_frequencies: int,
        min_freq_exp: float = 0.,
        max_freq_exp: Optional[float] = None,
        include_input: bool = False
    ) -> None:
        super().__init__()

        if max_freq_exp is None:
            max_freq_exp = num_frequencies - 1

        self.in_dim = in_dim
        self.num_frequencies = num_frequencies
        self.min_freq = min_freq_exp
        self.max_freq = max_freq_exp
        self.include_input = include_input

    def get_out_dim(self) -> int:
        """Return the size of the last dimension produced by ``forward``.

        Raises:
            ValueError: If ``in_dim`` is ``None``.
        """
        if self.in_dim is None:
            raise ValueError("Input dimension has not been set")
        # One sine and one cosine per (axis, frequency) pair.
        out_dim = self.in_dim * self.num_frequencies * 2
        if self.include_input:
            out_dim += self.in_dim
        return out_dim

    def forward(
        self,
        in_tensor: Tensor
    ) -> Tensor:
        """Calculates NeRF encoding.

        Args:
            in_tensor: For best performance, the input tensor should be between 0 and 1. [*bs, input_dim]

        Returns:
            Tensor of shape [*bs, output_dim]. The layout along the last
            dimension is: the raw input (only if ``include_input``), then the
            sines of every scaled coordinate (axis-major, frequencies
            innermost), then the cosines in the same order. The sinusoidal
            part lies between -1 and 1.
        """
        # Frequencies are rebuilt on every call, on the input's device.
        freqs = 2 ** torch.linspace(self.min_freq, self.max_freq, self.num_frequencies, device=in_tensor.device)
        scaled_inputs = in_tensor[..., None] * freqs  # [..., "input_dim", "num_scales"]

        # Merge the trailing (axis, frequency) dims. `reshape` with an explicit
        # size is used instead of `view(..., -1)` because:
        #   * `-1` cannot be inferred for an empty batch (0 elements), and
        #   * `view` rejects layouts that are not stride-compatible, which can
        #     happen when `in_tensor` is non-contiguous (e.g. transposed).
        lead_shape = scaled_inputs.shape[:-2]
        flat_size = scaled_inputs.shape[-2:].numel()
        scaled_inputs = scaled_inputs.reshape(*lead_shape, flat_size)  # [..., "input_dim" * "num_scales"]

        # sin(x + pi/2) == cos(x): a single sin call yields both halves.
        encoded_inputs = torch.sin(torch.cat([scaled_inputs, scaled_inputs + torch.pi / 2.0], dim=-1))

        if self.include_input:
            encoded_inputs = torch.cat([in_tensor, encoded_inputs], dim=-1)
        return encoded_inputs