Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from ..utils.general_utils import config_to_primitive | |
| from dataclasses import dataclass | |
def get_activation(name):
    """Resolve an activation name to a callable.

    ``None`` and ``"none"`` (case-insensitive) give the identity function.
    ``"sigmoid-mipnerf"`` gives a sigmoid whose output is scaled and shifted
    so that it extends 0.001 beyond each end of (0, 1). Any other name is
    lower-cased and looked up as an attribute of ``torch.nn.functional``.

    Args:
        name: Activation name, or ``None``.

    Returns:
        A callable applying the requested activation.

    Raises:
        ValueError: If the lower-cased name is not an attribute of
            ``torch.nn.functional``.
    """

    def identity(x):
        return x

    def mipnerf_sigmoid(x):
        # Uses sigmoid clamping from MipNeRF
        return torch.sigmoid(x) * (1 + 2 * 0.001) - 0.001

    if name is None:
        return identity

    key = name.lower()
    if key == "none":
        return identity
    if key == "sigmoid-mipnerf":
        return mipnerf_sigmoid

    try:
        resolved = getattr(F, key)
    except AttributeError:
        raise ValueError(f"Unknown activation function: {key}")
    return resolved
class VanillaMLP(nn.Module):
    """Fully connected network of bias-free linear layers separated by ReLU.

    The stack is an input linear layer + ReLU, then ``n_hidden_layers - 1``
    hidden linear + ReLU pairs, then an output linear layer. The activation
    named by ``config.output_activation`` is applied to the result.
    """

    def __init__(self, dim_in, dim_out, config):
        """Build the layer stack.

        Args:
            dim_in: Number of input features.
            dim_out: Number of output features.
            config: Object exposing ``n_neurons``, ``n_hidden_layers`` and
                ``output_activation``, or a dict that is expanded into
                ``MLPConfig(**config)``.
        """
        super().__init__()
        # Accept a plain dict and promote it to an MLPConfig.
        if isinstance(config, dict):
            config = MLPConfig(**config)

        self.n_neurons = config.n_neurons
        self.n_hidden_layers = config.n_hidden_layers
        width = self.n_neurons

        # Input projection followed by its activation.
        modules = [
            self.make_linear(dim_in, width, is_first=True, is_last=False),
            self.make_activation(),
        ]
        # Remaining hidden blocks; the input projection counts as the first.
        for _ in range(self.n_hidden_layers - 1):
            modules.append(
                self.make_linear(width, width, is_first=False, is_last=False)
            )
            modules.append(self.make_activation())
        # Output projection; no activation module is appended after it.
        modules.append(
            self.make_linear(width, dim_out, is_first=False, is_last=True)
        )

        self.layers = nn.Sequential(*modules)
        self.output_activation = get_activation(config.output_activation)

    def forward(self, x):
        """Run ``x`` through the layer stack and the output activation."""
        # Autocast is switched off here: the original author observed that
        # parameters ended up with empty gradients when it was enabled in AMP.
        with torch.cuda.amp.autocast(enabled=False):
            hidden = self.layers(x)
            return self.output_activation(hidden)

    def make_linear(self, dim_in, dim_out, is_first, is_last):
        """Create one bias-free linear layer.

        ``is_first`` and ``is_last`` are accepted but not used here.
        """
        return nn.Linear(dim_in, dim_out, bias=False)

    def make_activation(self):
        """Create the hidden-layer activation: an in-place ReLU."""
        return nn.ReLU(inplace=True)
@dataclass
class MLPConfig:
    """Settings for building an MLP.

    Declared as a dataclass so that it can be built from keyword arguments,
    as in ``MLPConfig(**config_dict)``; fields that are not supplied keep the
    defaults below.
    """

    # Network type selector; get_mlp only recognises "VanillaMLP".
    otype: str = "VanillaMLP"
    # Hidden activation name. NOTE(review): VanillaMLP.make_activation in this
    # file always returns ReLU and does not read this field.
    activation: str = "ReLU"
    # Name passed to get_activation for the final output.
    output_activation: str = "none"
    # Width of each hidden linear layer.
    n_neurons: int = 64
    # Hidden-layer count as read by VanillaMLP.
    n_hidden_layers: int = 2
def get_mlp(input_dim, output_dim, config):
    """Build the MLP selected by ``config.otype``.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output features.
        config: Configuration object with an ``otype`` attribute, or a dict
            that is expanded into ``MLPConfig(**config)``.

    Returns:
        A ``VanillaMLP`` instance when ``otype`` is ``"VanillaMLP"``.

    Raises:
        ValueError: If ``otype`` names any other network type.
    """
    # Promote a plain dict to an MLPConfig; other objects pass through as-is.
    cfg = MLPConfig(**config) if isinstance(config, dict) else config

    if cfg.otype != "VanillaMLP":
        raise ValueError(f"Unknown MLP type: {cfg.otype}")
    return VanillaMLP(input_dim, output_dim, cfg)