Spaces:
Paused
Paused
| """FiLM Siren MLP as per https://marcoamonteiro.github.io/pi-GAN-website/.""" | |
| from typing import Optional | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
def kaiming_leaky_init(m):
    """Kaiming-normal initialise the weight of Linear-like modules.

    Meant to be passed to ``nn.Module.apply``. Any module whose class name
    contains the substring ``"Linear"`` gets its ``weight`` redrawn with
    ``kaiming_normal_`` (fan-in mode, leaky-ReLU gain with negative slope 0.2).
    Biases and all other modules are left untouched.
    """
    if "Linear" not in m.__class__.__name__:
        return
    torch.nn.init.kaiming_normal_(
        m.weight,
        a=0.2,
        mode="fan_in",
        nonlinearity="leaky_relu",
    )
def frequency_init(freq):
    """Build a SIREN-style weight initialiser scaled down by ``freq``.

    The returned callable is meant for ``nn.Module.apply``. For every
    ``nn.Linear`` it redraws the weight uniformly from
    ``[-sqrt(6 / fan_in) / freq, sqrt(6 / fan_in) / freq]``, where ``fan_in``
    is the size of the weight's last dimension. Biases and non-Linear modules
    are left unchanged.
    """

    def init(m):
        if not isinstance(m, nn.Linear):
            return
        with torch.no_grad():
            fan_in = m.weight.size(-1)
            bound = np.sqrt(6 / fan_in) / freq
            m.weight.uniform_(-bound, bound)

    return init
def first_layer_film_sine_init(m):
    """Initialise the first FiLM-SIREN layer's weight.

    Meant for ``nn.Module.apply``. For an ``nn.Linear`` the weight is redrawn
    uniformly from ``[-1 / fan_in, 1 / fan_in]`` (``fan_in`` = size of the
    weight's last dimension). Biases and other modules are left unchanged.
    """
    if isinstance(m, nn.Linear):
        with torch.no_grad():
            bound = 1 / m.weight.size(-1)
            m.weight.uniform_(-bound, bound)
class CustomMappingNetwork(nn.Module):
    """MLP that maps a conditioning code to FiLM frequencies and phase shifts.

    Architecture: ``map_hidden_layers`` blocks of (Linear -> LeakyReLU(0.2)),
    followed by a final Linear with ``map_output_dim`` outputs. The output's
    last dimension is split in half: the first half is returned as
    frequencies, the second half as phase shifts.

    Args:
        in_features: Width of the conditioning input ``z`` (last dimension).
        map_hidden_layers: Number of hidden Linear + LeakyReLU blocks; may be 0.
        map_hidden_dim: Width of each hidden layer.
        map_output_dim: Width of the final layer's output.
    """

    def __init__(self, in_features, map_hidden_layers, map_hidden_dim, map_output_dim):
        super().__init__()
        layers = []
        for _ in range(map_hidden_layers):
            layers.append(nn.Linear(in_features, map_hidden_dim))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            in_features = map_hidden_dim
        # Use the running ``in_features`` rather than ``map_hidden_dim``. They
        # are equal when there is at least one hidden layer. With zero hidden
        # layers, this keeps the output layer compatible with the input width.
        layers.append(nn.Linear(in_features, map_output_dim))
        self.network = nn.Sequential(*layers)
        self.network.apply(kaiming_leaky_init)
        # Scale the output layer's initial weights down by 4x; presumably this
        # keeps the initial frequencies/phase shifts small (pi-GAN style).
        with torch.no_grad():
            self.network[-1].weight *= 0.25

    def forward(self, z):
        """Return ``(frequencies, phase_shifts)`` computed from ``z``.

        The network output is split at ``floor(width / 2)`` along its last
        dimension; if the width is odd, ``phase_shifts`` gets the extra entry.
        """
        frequencies_offsets = self.network(z)
        # torch.div is kept from the original; presumably chosen over ``//``
        # to avoid floor-division warnings when sizes are traced tensors.
        half = torch.div(frequencies_offsets.shape[-1], 2, rounding_mode="floor")
        frequencies = frequencies_offsets[..., :half]
        phase_shifts = frequencies_offsets[..., half:]
        return frequencies, phase_shifts
class FiLMLayer(nn.Module):
    """Linear layer followed by a FiLM-modulated sine activation.

    Computes ``sin(freq * (W x + b) + phase_shift)``. ``freq`` and
    ``phase_shift`` are broadcast to the linear output with ``expand_as``.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.layer = nn.Linear(input_dim, hidden_dim)

    def forward(self, x, freq, phase_shift):
        """Apply the linear map, then ``sin(freq * h + phase_shift)``."""
        h = self.layer(x)
        scale = freq.expand_as(h)
        shift = phase_shift.expand_as(h)
        return torch.sin(scale * h + shift)
class FiLMSiren(nn.Module):
    """FiLM-conditioned SIREN network.

    A stack of ``FiLMLayer`` sine layers. Their per-layer frequencies and
    phase shifts are produced by a ``CustomMappingNetwork`` from a separate
    conditioning input. The network optionally ends in a plain linear layer
    and/or an output activation.

    Args:
        in_dim: Width of the input ``x`` (last dimension); must be positive.
        hidden_layers: Number of hidden FiLM layers. The first FiLM layer is
            always created, so values below 1 still yield one hidden layer.
        hidden_features: Width of each hidden FiLM layer.
        mapping_network_in_dim: Width of the conditioning input.
        mapping_network_layers: Number of hidden layers in the mapping network.
        mapping_network_features: Hidden width of the mapping network.
        out_dim: Output width; ``None`` falls back to ``hidden_features``.
        outermost_linear: If True, the output layer is an unmodulated
            ``nn.Linear``; otherwise it is one more FiLM sine layer.
        out_activation: Optional module applied to the final output.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_layers: int,
        hidden_features: int,
        mapping_network_in_dim: int,
        mapping_network_layers: int,
        mapping_network_features: int,
        out_dim: Optional[int],
        outermost_linear: bool = False,
        out_activation: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        self.in_dim = in_dim
        assert self.in_dim > 0
        self.out_dim = out_dim if out_dim is not None else hidden_features
        self.hidden_layers = hidden_layers
        self.hidden_features = hidden_features
        self.mapping_network_in_dim = mapping_network_in_dim
        self.mapping_network_layers = mapping_network_layers
        self.mapping_network_features = mapping_network_features
        self.outermost_linear = outermost_linear
        self.out_activation = out_activation

        self.net = nn.ModuleList()
        self.net.append(FiLMLayer(self.in_dim, self.hidden_features))
        for _ in range(self.hidden_layers - 1):
            self.net.append(FiLMLayer(self.hidden_features, self.hidden_features))

        self.final_layer = None
        if self.outermost_linear:
            self.final_layer = nn.Linear(self.hidden_features, self.out_dim)
            self.final_layer.apply(frequency_init(25))
        else:
            # The last FiLM layer emits ``out_dim`` features, which can differ
            # from ``hidden_features``; _film_widths accounts for this.
            self.net.append(FiLMLayer(self.hidden_features, self.out_dim))

        self.mapping_network = CustomMappingNetwork(
            in_features=self.mapping_network_in_dim,
            map_hidden_layers=self.mapping_network_layers,
            map_hidden_dim=self.mapping_network_features,
            # One frequency and one phase shift per FiLM output feature. This
            # equals len(self.net) * hidden_features * 2 whenever every FiLM
            # layer is hidden_features wide.
            map_output_dim=2 * sum(self._film_widths()),
        )

        # Initialisation order is kept as-is so seeded runs reproduce the
        # same weights.
        self.net.apply(frequency_init(25))
        self.net[0].apply(first_layer_film_sine_init)

    def _film_widths(self):
        """Return the output width of each FiLM layer in ``self.net``, in order."""
        return [film.layer.out_features for film in self.net]

    def forward_with_frequencies_phase_shifts(self, x, frequencies, phase_shifts):
        """Run the network with precomputed conditional frequencies and phase shifts.

        Args:
            x: Input whose last dimension is ``in_dim``.
            frequencies: Raw frequencies (e.g. from the mapping network). They
                are rescaled here as ``frequencies * 15 + 30`` before use.
            phase_shifts: Phase shifts matching ``frequencies``.

        The last dimension of ``frequencies`` / ``phase_shifts`` is consumed in
        consecutive chunks, one per FiLM layer. Each chunk is as wide as that
        layer's output.

        Returns:
            The network output, after ``final_layer`` and ``out_activation``
            when they are set.
        """
        # Rescaling constants kept from the original; presumably taken from
        # pi-GAN — confirm before changing.
        frequencies = frequencies * 15 + 30
        start = 0
        for layer, width in zip(self.net, self._film_widths()):
            end = start + width
            x = layer(x, frequencies[..., start:end], phase_shifts[..., start:end])
            start = end
        if self.final_layer is not None:
            x = self.final_layer(x)
        if self.out_activation is not None:
            x = self.out_activation(x)
        return x

    def forward(self, x, conditioning_input):
        """Map ``conditioning_input`` to FiLM parameters, then run the network on ``x``."""
        frequencies, phase_shifts = self.mapping_network(conditioning_input)
        return self.forward_with_frequencies_phase_shifts(x, frequencies, phase_shifts)