| | from dataclasses import dataclass, field |
| | from typing import Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from einops import rearrange |
| |
|
| | from ..utils import BaseModule |
| |
|
| |
|
class TriplaneUpsampleNetwork(BaseModule):
    """Upsample a stack of three feature planes with one shared layer.

    The same ``nn.ConvTranspose2d`` (kernel size 2, stride 2) is applied to
    each of the three planes, which doubles both spatial dimensions.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Number of channels in each incoming plane.
        in_channels: int
        # Number of channels in each upsampled plane.
        out_channels: int

    cfg: Config

    def configure(self) -> None:
        """Create the transposed convolution from the config values.

        NOTE(review): presumably invoked by ``BaseModule`` during
        construction -- confirm against ``BaseModule``.
        """
        self.upsample = nn.ConvTranspose2d(
            in_channels=self.cfg.in_channels,
            out_channels=self.cfg.out_channels,
            kernel_size=2,
            stride=2,
        )

    def forward(self, triplanes: torch.Tensor) -> torch.Tensor:
        """Upsample every plane of ``triplanes``.

        Args:
            triplanes: Tensor laid out as ``(B, 3, Ci, Hp, Wp)``; the plane
                axis must have size 3 (enforced by ``rearrange``).

        Returns:
            Tensor laid out as ``(B, 3, Co, H, W)`` where ``H`` and ``W`` are
            the spatial sizes produced by the transposed convolution.
        """
        # Fold the plane axis into the batch axis so a single 2D layer
        # processes all three planes in one call.
        stacked = rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
        upsampled = self.upsample(stacked)
        # Restore the separate plane axis.
        return rearrange(upsampled, "(B Np) Co Hp Wp -> B Np Co Hp Wp", Np=3)
| |
|
| |
|
class NeRFMLP(BaseModule):
    """Fully connected network mapping feature vectors to 4 output channels.

    The network is ``Linear -> activation`` repeated, followed by a final
    ``Linear`` with 4 outputs. ``forward`` splits those outputs into a
    1-channel ``"density"`` and a 3-channel ``"features"`` tensor.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Size of the last dimension of the input tensor.
        in_channels: int
        # Width of every hidden linear layer.
        n_neurons: int
        # Number of (Linear, activation) pairs before the output layer.
        # Values of 0 and 1 both yield a single pair, because the first pair
        # is always created.
        n_hidden_layers: int
        # Activation name understood by ``make_activation``.
        activation: str = "relu"
        # Whether the linear layers carry a bias term.
        bias: bool = True
        # Weight initialisation scheme; ``None`` keeps the PyTorch default.
        weight_init: Optional[str] = "kaiming_uniform"
        # Bias initialisation scheme; ``None`` keeps the PyTorch default.
        bias_init: Optional[str] = None

    cfg: Config

    def configure(self) -> None:
        """Build ``self.layers`` from the config.

        The layer order (and therefore the ``nn.Sequential`` indices and the
        order in which weights are initialised) is:
        input Linear, activation, [hidden Linear, activation] * (n - 1),
        output Linear.
        """
        cfg = self.cfg
        # Every linear layer shares the same bias / initialisation settings.
        linear_kwargs = dict(
            bias=cfg.bias,
            weight_init=cfg.weight_init,
            bias_init=cfg.bias_init,
        )

        layers = [
            self.make_linear(cfg.in_channels, cfg.n_neurons, **linear_kwargs),
            self.make_activation(cfg.activation),
        ]
        for _ in range(cfg.n_hidden_layers - 1):
            layers += [
                self.make_linear(cfg.n_neurons, cfg.n_neurons, **linear_kwargs),
                self.make_activation(cfg.activation),
            ]
        # 4 outputs: 1 density channel + 3 feature channels (see ``forward``).
        layers.append(self.make_linear(cfg.n_neurons, 4, **linear_kwargs))

        self.layers = nn.Sequential(*layers)

    def make_linear(
        self,
        dim_in,
        dim_out,
        bias=True,
        weight_init=None,
        bias_init=None,
    ):
        """Create an ``nn.Linear`` and apply the requested initialisation.

        Args:
            dim_in: Input feature size.
            dim_out: Output feature size.
            bias: Whether the layer has a bias term.
            weight_init: ``None`` (PyTorch default) or ``"kaiming_uniform"``.
            bias_init: ``None`` (PyTorch default) or ``"zero"``. Ignored when
                ``bias`` is false.

        Returns:
            The initialised ``nn.Linear`` layer.

        Raises:
            NotImplementedError: If ``weight_init`` or ``bias_init`` names an
                unsupported scheme.
        """
        layer = nn.Linear(dim_in, dim_out, bias=bias)

        if weight_init is None:
            pass
        elif weight_init == "kaiming_uniform":
            # NOTE(review): the gain is always computed for "relu", even when
            # the configured activation is "silu" -- kept as in the original.
            torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu")
        else:
            raise NotImplementedError(
                f"Unsupported weight_init: {weight_init!r}"
            )

        if bias:
            if bias_init is None:
                pass
            elif bias_init == "zero":
                torch.nn.init.zeros_(layer.bias)
            else:
                raise NotImplementedError(
                    f"Unsupported bias_init: {bias_init!r}"
                )

        return layer

    def make_activation(self, activation):
        """Return a new in-place activation module for ``activation``.

        Args:
            activation: ``"relu"`` or ``"silu"``.

        Raises:
            NotImplementedError: If ``activation`` is any other value.
        """
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError(
                f"Unsupported activation: {activation!r}"
            )

    def forward(self, x):
        """Run the MLP over the last dimension of ``x``.

        Args:
            x: Tensor whose last dimension holds the input features; any
                number of leading dimensions is accepted.

        Returns:
            Dict with ``"density"`` (output channel 0) and ``"features"``
            (output channels 1-3), each keeping the leading dimensions of
            ``x``.
        """
        inp_shape = x.shape[:-1]
        # Flatten all leading dimensions into one batch axis for the MLP.
        x = x.reshape(-1, x.shape[-1])

        features = self.layers(x)
        # Restore the original leading dimensions.
        features = features.reshape(*inp_shape, -1)
        out = {"density": features[..., 0:1], "features": features[..., 1:4]}

        return out
| |
|