| |
| """Contains the implementation of generator described in VolumeGAN. |
| |
| Paper: https://arxiv.org/pdf/2112.10759.pdf |
| """ |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange |
|
|
| from .stylegan2_generator import MappingNetwork |
| from .stylegan2_generator import ModulateConvLayer |
| from .stylegan2_generator import ConvLayer |
| from .stylegan2_generator import DenseLayer |
| from third_party.stylegan2_official_ops import upfirdn2d |
| from .rendering import Renderer |
| from .rendering import FeatureExtractor |
| from .utils.ops import all_gather |
|
|
|
|
class VolumeGANGenerator(nn.Module):
    """Defines the generator network in VolumeGAN.

    As wired up in this class, one forward pass runs the following stages:

    (1) `mapping` converts `z` (and optionally `label`) into the latent codes
        `w` and `wp`.
    (2) `ref_representation_generator` (a `FeatureVolume`) builds a 3D feature
        volume from the first `num_nerf_layers` entries of `wp`.
    (3) `renderer` combines the feature volume, `position_encoder`,
        `feature_extractor`, `post_module` and `fc_head`; its
        `'composite_rgb'` output is reshaped into a 2D feature map of spatial
        size `nerf_res x nerf_res`.
    (4) `post_neural_renderer` turns that feature map into the final image,
        driven by the remaining entries of `wp`.
    """

    def __init__(
            self,
            # Settings for the mapping network.
            z_dim=512,
            w_dim=512,
            repeat_w=True,
            normalize_z=True,
            mapping_layers=8,
            mapping_fmaps=512,
            mapping_use_wscale=True,
            mapping_wscale_gain=1.0,
            mapping_lr_mul=0.01,
            # Settings for the conditional (label) branch of the mapping
            # network.
            label_dim=0,
            embedding_dim=512,
            embedding_bias=True,
            embedding_use_wscale=True,
            embedding_wscale_gian=1.0,
            embedding_lr_mul=1.0,
            normalize_embedding=True,
            normalize_embedding_latent=False,
            # Settings for the post neural renderer.
            resolution=-1,
            nerf_res=32,
            image_channels=3,
            final_tanh=False,
            demodulate=True,
            use_wscale=True,
            wscale_gain=1.0,
            lr_mul=1.0,
            noise_type='spatial',
            fmaps_base=32 << 10,
            fmaps_max=512,
            filter_kernel=(1, 3, 3, 1),
            conv_clamp=None,
            eps=1e-8,
            rgb_init_res_out=True,
            # Settings for the feature volume.
            fv_cfg=dict(feat_res=32,
                        init_res=4,
                        base_channels=256,
                        output_channels=32,
                        w_dim=512),
            # Settings for the positional encoder.
            embed_cfg=dict(input_dim=3, max_freq_log2=10 - 1, N_freqs=10),
            # Settings for the MLP network and its heads.
            fg_cfg=dict(num_layers=4, hidden_dim=256, activation_type='lrelu'),
            bg_cfg=None,
            out_dim=512,
            # Settings for rendering.
            rendering_kwargs=None):
        """Initializes the generator and all of its sub-networks.

        Args:
            z_dim .. mapping_lr_mul: Forwarded to `MappingNetwork`.
            label_dim .. normalize_embedding_latent: Forwarded to
                `MappingNetwork`. NOTE: `embedding_wscale_gian` keeps its
                historical spelling because it is part of the public
                signature and is forwarded under that name.
            resolution: Resolution of the output image. Must be no smaller
                than `nerf_res`.
            nerf_res: Spatial size of the 2D feature map produced by the
                renderer, also the initial resolution of the post neural
                renderer.
            image_channels .. rgb_init_res_out: Forwarded to
                `PostNeuralRendererNetwork`.
            fv_cfg: Keyword arguments of `FeatureVolume`.
            embed_cfg: Keyword arguments of `PositionEncoder`.
            fg_cfg: Settings of the MLP (`num_layers`, `hidden_dim`,
                `activation_type`), shared by `NeRFMLPNetwork` and `FCHead`.
            bg_cfg: Forwarded to `NeRFMLPNetwork` and `FCHead`.
            out_dim: Output dimension of the colour head in `FCHead`.
            rendering_kwargs: Options handed to the renderer on every forward
                pass. `None` (default) means an empty dict. If it contains a
                truthy `'debug_mode'`, weights are loaded from
                `'cur_state'` / `'official_state'` right after construction.

        Raises:
            ValueError: If `resolution` is smaller than `nerf_res`, or
                `nerf_res` is not positive.
        """
        super().__init__()

        # Fail early with a readable message; otherwise the layer-count
        # computation below breaks on a NaN / infinite `log2` result.
        if nerf_res <= 0 or resolution < nerf_res:
            raise ValueError(
                f'`resolution` ({resolution}) must be no smaller than the '
                f'positive `nerf_res` ({nerf_res}).')

        # A fresh dict per instance: a mutable default would be shared by
        # every generator constructed without explicit rendering options.
        if rendering_kwargs is None:
            rendering_kwargs = {}

        self.z_dim = z_dim
        self.w_dim = w_dim
        self.repeat_w = repeat_w
        self.normalize_z = normalize_z
        self.mapping_layers = mapping_layers
        self.mapping_fmaps = mapping_fmaps
        self.mapping_use_wscale = mapping_use_wscale
        self.mapping_wscale_gain = mapping_wscale_gain
        self.mapping_lr_mul = mapping_lr_mul

        self.latent_dim = (z_dim,)
        self.label_size = label_dim
        self.label_dim = label_dim
        self.embedding_dim = embedding_dim
        self.embedding_bias = embedding_bias
        self.embedding_use_wscale = embedding_use_wscale
        self.embedding_wscale_gain = embedding_wscale_gian
        self.embedding_lr_mul = embedding_lr_mul
        self.normalize_embedding = normalize_embedding
        self.normalize_embedding_latent = normalize_embedding_latent

        self.resolution = resolution
        self.nerf_res = nerf_res
        self.image_channels = image_channels
        self.final_tanh = final_tanh
        self.demodulate = demodulate
        self.use_wscale = use_wscale
        self.wscale_gain = wscale_gain
        self.lr_mul = lr_mul
        self.noise_type = noise_type.lower()
        self.fmaps_base = fmaps_base
        self.fmaps_max = fmaps_max
        self.filter_kernel = filter_kernel
        self.conv_clamp = conv_clamp
        self.eps = eps

        # One latent code per MLP layer, plus two per resolution block of the
        # post neural renderer (from `nerf_res` up to `resolution`).
        self.num_nerf_layers = fg_cfg['num_layers']
        self.num_cnn_layers = int(np.log2(resolution // nerf_res * 2)) * 2
        self.num_layers = self.num_nerf_layers + self.num_cnn_layers

        # Running average of `w`, used by the truncation trick.
        if self.repeat_w:
            self.register_buffer('w_avg', torch.zeros(w_dim))
        else:
            self.register_buffer('w_avg', torch.zeros(self.num_layers * w_dim))

        # Mapping network: z (and label) -> w / wp.
        self.mapping = MappingNetwork(
            input_dim=z_dim,
            output_dim=w_dim,
            num_outputs=self.num_layers,
            repeat_output=repeat_w,
            normalize_input=normalize_z,
            num_layers=mapping_layers,
            hidden_dim=mapping_fmaps,
            use_wscale=mapping_use_wscale,
            wscale_gain=mapping_wscale_gain,
            lr_mul=mapping_lr_mul,
            label_dim=label_dim,
            embedding_dim=embedding_dim,
            embedding_bias=embedding_bias,
            embedding_use_wscale=embedding_use_wscale,
            embedding_wscale_gian=embedding_wscale_gian,
            embedding_lr_mul=embedding_lr_mul,
            normalize_embedding=normalize_embedding,
            normalize_embedding_latent=normalize_embedding_latent,
            eps=eps)

        # Renderer.
        self.renderer = Renderer()

        # Generator of the reference representation (the feature volume).
        self.ref_representation_generator = FeatureVolume(**fv_cfg)

        # Positional encoder of point coordinates.
        self.position_encoder = PositionEncoder(**embed_cfg)

        # Extractor that reads per-point features out of the feature volume.
        self.feature_extractor = FeatureExtractor(ref_mode='feature_volume')

        # MLP applied to the concatenation of the encoded coordinates and the
        # features sampled from the volume.
        self.post_module = NeRFMLPNetwork(
            input_dim=(self.position_encoder.out_dim +
                       fv_cfg['output_channels']),
            fg_cfg=fg_cfg,
            bg_cfg=bg_cfg)

        # Heads decoding MLP features into `sigma` and `rgb`.
        self.fc_head = FCHead(fg_cfg=fg_cfg, bg_cfg=bg_cfg, out_dim=out_dim)

        # Neural renderer producing the final image.
        self.post_neural_renderer = PostNeuralRendererNetwork(
            resolution=resolution,
            init_res=nerf_res,
            w_dim=w_dim,
            image_channels=image_channels,
            final_tanh=final_tanh,
            demodulate=demodulate,
            use_wscale=use_wscale,
            wscale_gain=wscale_gain,
            lr_mul=lr_mul,
            noise_type=noise_type,
            fmaps_base=fmaps_base,
            filter_kernel=filter_kernel,
            fmaps_max=fmaps_max,
            conv_clamp=conv_clamp,
            eps=eps,
            rgb_init_res_out=rgb_init_res_out)

        # Rendering options, forwarded to the renderer on every call.
        self.rendering_kwargs = rendering_kwargs

        # Maps parameter-name prefixes of this implementation to the prefixes
        # used by the official checkpoint (see `set_weights_from_official()`).
        self.cur_to_official_part_mapping = {
            'w_avg': 'w_avg',
            'mapping': 'mapping',
            'ref_representation_generator': 'nerfmlp.fv',
            'post_module.fg_mlp': 'nerfmlp.fg_mlps',
            'fc_head.fg_sigma_head': 'nerfmlp.fg_density',
            'fc_head.fg_rgb_head': 'nerfmlp.fg_color',
            'post_neural_renderer': 'synthesis'
        }

        # Debug only: overwrite the parameters with official weights.
        if self.rendering_kwargs.get('debug_mode', False):
            self.set_weights_from_official(
                self.rendering_kwargs.get('cur_state', None),
                self.rendering_kwargs.get('official_state', None))

    def get_cur_to_official_full_mapping(self, keys_cur):
        """Expands the prefix mapping into a full name-to-name mapping.

        A name is translated only when it *starts with* one of the prefixes
        in `self.cur_to_official_part_mapping`. Matching the prefix anywhere
        in the name would, for instance, let `'mapping'` hit
        `'ref_representation_generator.mapping_network.weight'` and yield a
        meaningless entry.

        Args:
            keys_cur: Iterable of full names (state dict keys) of the current
                implementation. It is iterated once per prefix, so it must be
                re-iterable (e.g. a list or a dict key view).

        Returns:
            A dict mapping each matched name to its official counterpart.
            Names matching no prefix are omitted.
        """
        full_mapping = {}
        for cur_prefix, official_prefix in (
                self.cur_to_official_part_mapping.items()):
            for full_key in keys_cur:
                if full_key.startswith(cur_prefix):
                    suffix = full_key[len(cur_prefix):]
                    full_mapping[full_key] = official_prefix + suffix
        return full_mapping

    def set_weights_from_official(self, cur_state, official_state):
        """Overwrites every parameter with the matching official weight.

        Both arguments are indexed as `state['models']['generator_smooth']`.
        The keys of `cur_state` are only used to build the name translation
        table; the tensors are taken from `official_state`.

        Args:
            cur_state: Checkpoint of the current implementation.
            official_state: Checkpoint of the official implementation.
        """
        keys_cur = cur_state['models']['generator_smooth'].keys()
        self.cur_to_official_full_mapping = (
            self.get_cur_to_official_full_mapping(keys_cur))
        official_weights = official_state['models']['generator_smooth']
        for name, param in self.named_parameters():
            param.data = official_weights[
                self.cur_to_official_full_mapping[name]]

    def forward(
            self,
            z,
            label=None,
            lod=None,
            w_moving_decay=None,
            sync_w_avg=False,
            style_mixing_prob=None,
            trunc_psi=None,
            trunc_layers=None,
            noise_mode='const',
            fused_modulate=False,
            impl='cuda',
            fp16_res=None,
    ):
        """Synthesizes images from latent codes.

        Args:
            z: Input latent codes, fed to the mapping network.
            label: Optional labels, fed to the mapping network.
            lod: Level-of-details passed to the post neural renderer. If
                `None`, the renderer's own `lod` buffer is used.
            w_moving_decay: If given (training only), `w_avg` is updated as a
                moving average of the batch mean of `w` with this decay.
            sync_w_avg: Whether to gather `w` with `all_gather()` before
                averaging.
            style_mixing_prob: If given (training only), probability of
                swapping the codes after a random cutoff with codes mapped
                from a fresh `z`. The cutoff is drawn from the layers of the
                post neural renderer only, so the MLP codes are never mixed.
            trunc_psi: Truncation strength (evaluation only). `None` means
                1.0, i.e. no truncation.
            trunc_layers: Number of leading layers to truncate (evaluation
                only). `None` means 0.
            noise_mode: Forwarded to the post neural renderer.
            fused_modulate: Forwarded to the post neural renderer.
            impl: Implementation mode forwarded to the mapping network and
                the post neural renderer.
            fp16_res: Forwarded to the post neural renderer.

        Returns:
            A dict merging the results of the mapping network with those of
            the post neural renderer (which include `'image'`).
        """
        mapping_results = self.mapping(z, label, impl=impl)
        w = mapping_results['w']
        lod = self.post_neural_renderer.lod.item() if lod is None else lod

        # Track the moving average of `w` for the truncation trick.
        if self.training and w_moving_decay is not None:
            if sync_w_avg:
                batch_w_avg = all_gather(w.detach()).mean(dim=0)
            else:
                batch_w_avg = w.detach().mean(dim=0)
            self.w_avg.copy_(batch_w_avg.lerp(self.w_avg, w_moving_decay))

        wp = mapping_results['wp']

        # Style mixing, restricted to the codes of the post neural renderer.
        # NOTE: `wp` is modified in place, hence so is `mapping_results`.
        if self.training and style_mixing_prob is not None:
            if np.random.uniform() < style_mixing_prob:
                new_z = torch.randn_like(z)
                new_wp = self.mapping(new_z, label, impl=impl)['wp']
                current_layers = self.num_layers
                if current_layers > self.num_nerf_layers:
                    mixing_cutoff = np.random.randint(self.num_nerf_layers,
                                                      current_layers)
                    wp[:, mixing_cutoff:] = new_wp[:, mixing_cutoff:]

        # Truncation trick.
        if not self.training:
            trunc_psi = 1.0 if trunc_psi is None else trunc_psi
            trunc_layers = 0 if trunc_layers is None else trunc_layers
            if trunc_psi < 1.0 and trunc_layers > 0:
                w_avg = self.w_avg.reshape(1, -1, self.w_dim)[:, :trunc_layers]
                wp[:, :trunc_layers] = w_avg.lerp(
                    wp[:, :trunc_layers], trunc_psi)

        # Split the codes between the 3D part and the 2D neural renderer.
        nerf_w = wp[:, :self.num_nerf_layers]
        cnn_w = wp[:, self.num_nerf_layers:]

        feature_volume = self.ref_representation_generator(nerf_w)

        rendering_results = self.renderer(
            wp=nerf_w,
            feature_extractor=self.feature_extractor,
            rendering_options=self.rendering_kwargs,
            position_encoder=self.position_encoder,
            ref_representation=feature_volume,
            post_module=self.post_module,
            fc_head=self.fc_head)

        # Arrange the rendered features as a channel-first 2D map:
        # [N, nerf_res, nerf_res, C] -> [N, C, nerf_res, nerf_res].
        feature2d = rendering_results['composite_rgb']
        feature2d = feature2d.reshape(feature2d.shape[0], self.nerf_res,
                                      self.nerf_res, -1).permute(0, 3, 1, 2)

        # Pass the resolved `lod` on. Passing `None` here would silently drop
        # a level-of-details requested by the caller.
        final_results = self.post_neural_renderer(
            feature2d,
            cnn_w,
            lod=lod,
            noise_mode=noise_mode,
            fused_modulate=fused_modulate,
            impl=impl,
            fp16_res=fp16_res)

        return {**mapping_results, **final_results}
|
|
|
|
class PositionEncoder(nn.Module):
    """Encodes coordinates with a bank of periodic functions.

    Each input is multiplied by every frequency band and passed through every
    periodic function; the results (optionally preceded by the raw input) are
    concatenated along the last dimension.
    """

    def __init__(self,
                 input_dim,
                 max_freq_log2,
                 N_freqs,
                 log_sampling=True,
                 include_input=True,
                 periodic_fns=(torch.sin, torch.cos)):
        """Initializes with basic settings.

        Args:
            input_dim: Size of the last dimension of the input to embed.
            max_freq_log2: `log2` of the highest frequency; the lowest
                frequency is 1.
            N_freqs: Number of frequency bands.
            log_sampling: If True, the bands are spaced linearly in log-space
                (i.e. powers of two); otherwise linearly between 1 and the
                highest frequency. Defaults to True.
            include_input: If True, the raw input is the first part of the
                embedding. Defaults to True.
            periodic_fns: Periodic functions applied to the scaled input.
                Defaults to (torch.sin, torch.cos).
        """
        super().__init__()

        self.input_dim = input_dim
        self.include_input = include_input
        self.periodic_fns = periodic_fns

        # Raw input (if kept) plus one chunk per (frequency, function) pair.
        raw_dim = self.input_dim if self.include_input else 0
        encoded_dim = self.input_dim * N_freqs * len(self.periodic_fns)
        self.out_dim = raw_dim + encoded_dim

        if log_sampling:
            bands = 2.**torch.linspace(0., max_freq_log2, N_freqs)
        else:
            bands = torch.linspace(2.**0., 2.**max_freq_log2, N_freqs)

        # Stored as a plain list of Python floats rather than a tensor.
        self.freq_bands = bands.numpy().tolist()

    def forward(self, input):
        """Embeds `input`, whose last dimension must equal `input_dim`.

        Returns:
            A tensor with the same leading dimensions as `input` and a last
            dimension of size `out_dim`.
        """
        assert (input.shape[-1] == self.input_dim)

        pieces = [input] if self.include_input else []
        # Frequency-major order: all functions of band 0, then band 1, ...
        pieces += [
            fn(input * freq)
            for freq in self.freq_bands
            for fn in self.periodic_fns
        ]
        encoded = torch.cat(pieces, dim=-1)

        assert (encoded.shape[-1] == self.out_dim)

        return encoded
|
|
|
|
class FeatureVolume(nn.Module):
    """Defines feature volume in VolumeGAN.

    A learnable constant volume is refined by a stack of 3D convolutions.
    Every stage after the first one starts with a 2x upsampling, and every
    stage ends with an instance normalization whose scale and shift are
    predicted from the latent code, followed by a leaky ReLU.
    """

    def __init__(self,
                 feat_res=32,
                 init_res=4,
                 base_channels=256,
                 output_channels=32,
                 w_dim=512,
                 **kwargs):
        """Initializes the feature volume.

        Args:
            feat_res: Resolution of the produced volume.
            init_res: Resolution of the learnable constant volume.
            base_channels: Number of channels of the constant volume and of
                the first stage.
            output_channels: Number of channels of the produced volume.
            w_dim: Dimension of the latent code predicting scales and shifts.
            **kwargs: Ignored.
        """
        super().__init__()
        self.num_stages = int(np.log2(feat_res // init_res)) + 1

        self.const = nn.Parameter(
            torch.ones(1, base_channels, init_res, init_res, init_res))

        self.stage_channels = self._plan_stage_channels(
            self.num_stages, base_channels, output_channels)

        inplanes = base_channels
        for i, outplanes in enumerate(self.stage_channels):
            conv = nn.Conv3d(inplanes,
                             outplanes,
                             kernel_size=(3, 3, 3),
                             padding=(1, 1, 1))
            self.add_module(f'layer{i}', conv)
            instance_norm = InstanceNormLayer(num_features=outplanes,
                                              affine=False)
            self.add_module(f'instance_norm{i}', instance_norm)
            inplanes = outplanes

        # One (scale, shift) pair per channel of every stage.
        self.mapping_network = nn.Linear(w_dim, sum(self.stage_channels) * 2)
        self.mapping_network.apply(kaiming_leaky_init)
        with torch.no_grad():
            self.mapping_network.weight *= 0.25
        self.upsample = UpsamplingLayer()
        self.lrelu = nn.LeakyReLU(negative_slope=0.2)

    @staticmethod
    def _plan_stage_channels(num_stages, base_channels, output_channels):
        """Returns the number of output channels of every stage.

        The width starts at `base_channels` and is halved after every stage,
        never dropping below `output_channels`. The last stage always emits
        exactly `output_channels`, so the produced volume has the advertised
        width no matter how many stages there are. The override has to be
        applied *before* the last stage is built; applying it afterwards has
        no effect on any layer.
        """
        channels = []
        width = base_channels
        for stage in range(num_stages):
            if stage == num_stages - 1:
                width = output_channels
            channels.append(width)
            width = max(width // 2, output_channels)
        return channels

    def forward(self, w, **kwargs):
        """Generates the feature volume.

        Args:
            w: Latent codes. If 3-dimensional, only `w[:, 0]` is used.
            **kwargs: Ignored.

        Returns:
            The feature volume, with one sample per row of `w`.
        """
        _w = w[:, 0] if w.ndim == 3 else w
        scale_shifts = self.mapping_network(_w)
        # First half of the prediction holds the scales, second half the
        # shifts; both are laid out stage by stage.
        half = scale_shifts.shape[-1] // 2
        scales = scale_shifts[..., :half]
        shifts = scale_shifts[..., half:]

        x = self.const.repeat(w.shape[0], 1, 1, 1, 1)
        start = 0
        for idx, num_channels in enumerate(self.stage_channels):
            if idx != 0:
                x = self.upsample(x)
            x = getattr(self, f'layer{idx}')(x)

            end = start + num_channels
            # Append three singleton axes to broadcast over the volume.
            scale = scales[:, start:end][:, :, None, None, None]
            shift = shifts[:, start:end][:, :, None, None, None]
            start = end

            instance_norm = getattr(self, f'instance_norm{idx}')
            x = instance_norm(x, weight=scale, bias=shift)
            x = self.lrelu(x)

        return x
|
|
|
|
def kaiming_leaky_init(m):
    """Applies Kaiming-normal initialization to `Linear`-like modules.

    Meant to be used with `nn.Module.apply()`. A module is initialized when
    its class name contains `Linear`; any other module is left untouched.
    The initialization uses `fan_in` mode and a leaky ReLU slope of 0.2.
    """
    if 'Linear' not in type(m).__name__:
        return
    torch.nn.init.kaiming_normal_(m.weight,
                                  a=0.2,
                                  mode='fan_in',
                                  nonlinearity='leaky_relu')
|
|
|
|
class InstanceNormLayer(nn.Module):
    """Implements instance normalization over the last three dimensions.

    The scale and shift come from exactly one source: the layer's own
    parameters when `affine=True`, or the `weight` / `bias` arguments of
    `forward()` when `affine=False`.
    """

    def __init__(self, num_features, epsilon=1e-8, affine=False):
        """Initializes the layer.

        Args:
            num_features: Number of channels, used to shape the learnable
                scale and shift.
            epsilon: Small value added to the mean square for numerical
                stability.
            affine: Whether the layer owns a learnable scale and shift.
        """
        super().__init__()
        self.eps = epsilon
        self.affine = affine
        if not self.affine:
            return
        param_shape = (1, num_features, 1, 1, 1)
        self.weight = nn.Parameter(torch.Tensor(*param_shape))
        self.bias = nn.Parameter(torch.Tensor(*param_shape))
        self.weight.data.uniform_()
        self.bias.data.zero_()

    def forward(self, x, weight=None, bias=None):
        """Normalizes `x` per sample and channel, then scales and shifts it.

        Args:
            x: Input with the normalized axes at positions 2, 3 and 4.
            weight: External scale; required together with `bias` when the
                layer is not affine, and forbidden otherwise.
            bias: External shift, see `weight`.
        """
        reduce_dims = [2, 3, 4]
        centered = x - x.mean(dim=reduce_dims, keepdim=True)
        rms = torch.sqrt(
            (centered**2).mean(dim=reduce_dims, keepdim=True) + self.eps)
        normalized = centered / rms

        # Exactly one of "external scale/shift" and "affine" must hold.
        has_external = weight is not None and bias is not None
        assert has_external != bool(self.affine)

        if self.affine:
            return normalized * self.weight + self.bias
        return normalized * weight + bias
|
|
|
|
class UpsamplingLayer(nn.Module):
    """Upsamples the input with nearest-neighbor interpolation.

    With a scale factor not larger than 1 the layer is an identity that
    returns its input object unchanged.
    """

    def __init__(self, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        """Returns `x` enlarged by `scale_factor` along its spatial axes."""
        return x if self.scale_factor <= 1 else F.interpolate(
            x, scale_factor=self.scale_factor, mode='nearest')
|
|
|
|
class NeRFMLPNetwork(nn.Module):
    """Defines class of MLP Network described in VolumeGAN.

    Basically, this class takes in latent codes, point features and encoded
    point coordinates, and outputs the features of each point. The features
    are meant to be decoded afterwards by two fully-connected heads.
    """

    def __init__(self, input_dim, fg_cfg, bg_cfg=None):
        """Builds the MLP from `fg_cfg`; `bg_cfg` is accepted but unused."""
        super().__init__()
        self.fg_mlp = self.build_mlp(input_dim=input_dim, **fg_cfg)

    def build_mlp(self, input_dim, num_layers, hidden_dim, activation_type,
                  **kwargs):
        """Implements function to build the `MLP`.

        The `MLP` is a stack of `num_layers` `ModulateConvLayer`s with
        `kernel_size=1`, which act as fully-connected layers on inputs shaped
        `[N, C, R*K, 1]` instead of the usual `[N, C, H, W]`. The first layer
        maps `input_dim` to `hidden_dim`; the others keep `hidden_dim`.
        Extra keyword arguments are ignored.
        """
        shared_cfg = {
            'resolution': 32,
            'w_dim': 512,
            'kernel_size': 1,
            'add_bias': True,
            'scale_factor': 1,
            'filter_kernel': None,
            'demodulate': True,
            'use_wscale': True,
            'wscale_gain': 1,
            'lr_mul': 1,
            'noise_type': 'none',
            'conv_clamp': None,
            'eps': 1e-8,
        }
        widths = [input_dim] + [hidden_dim] * num_layers
        layers = nn.ModuleList()
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(
                ModulateConvLayer(in_channels=fan_in,
                                  out_channels=fan_out,
                                  activation_type=activation_type,
                                  **shared_cfg))
        return layers

    def forward(self,
                pre_point_features,
                wp,
                points_encoding=None,
                fused_modulate=False,
                impl='cuda'):
        """Runs the MLP on the per-point inputs.

        Args:
            pre_point_features: Per-point features, concatenated with
                `points_encoding` along dimension 1.
            wp: Latent codes. If 3-dimensional, layer `i` uses `wp[:, i]`;
                otherwise every layer uses `wp` as is.
            points_encoding: Encoded point coordinates; must be 4-dimensional.
            fused_modulate: Forwarded to every layer.
            impl: Forwarded to every layer.

        Returns:
            The output of the last layer.
        """
        # The values are unused; the unpacking enforces a 4-dimensional
        # `points_encoding`.
        _batch, _channels, _num_points, _ = points_encoding.shape

        features = torch.cat([pre_point_features, points_encoding], dim=1)
        layerwise_codes = wp.ndim == 3
        for layer_idx, layer in enumerate(self.fg_mlp):
            code = wp[:, layer_idx] if layerwise_codes else wp
            features, _ = layer(features,
                                code,
                                fused_modulate=fused_modulate,
                                impl=impl)

        return features
|
|
|
|
class FCHead(nn.Module):
    """Defines fully-connected layer head in VolumeGAN to decode `feature` into
    `sigma` and `rgb`."""

    def __init__(self, fg_cfg, bg_cfg=None, out_dim=512):
        """Builds the two linear heads.

        Args:
            fg_cfg: Settings of the MLP; only `hidden_dim` is read, as the
                input width of both heads.
            bg_cfg: Accepted but unused.
            out_dim: Output width of the `rgb` head. The `sigma` head always
                has a single output.
        """
        super().__init__()
        # Both heads share everything but their output width.
        head_cfg = dict(in_channels=fg_cfg['hidden_dim'],
                        add_bias=True,
                        init_bias=0.0,
                        use_wscale=True,
                        wscale_gain=1,
                        lr_mul=1,
                        activation_type='linear')
        self.fg_sigma_head = DenseLayer(out_channels=1, **head_cfg)
        self.fg_rgb_head = DenseLayer(out_channels=out_dim, **head_cfg)

    def forward(self, post_point_features, wp=None, dirs=None):
        """Decodes per-point features.

        Args:
            post_point_features: Features shaped `[N, C, R_K, 1]`; they are
                flattened to `[N * R_K, C]` before being fed to the heads.
            wp: Accepted but unused.
            dirs: Accepted but unused.

        Returns:
            A dict with the outputs of the two heads under `'sigma'` and
            `'rgb'`.
        """
        flat_features = rearrange(post_point_features,
                                  'N C (R_K) 1 -> (N R_K) C').contiguous()
        return {
            'sigma': self.fg_sigma_head(flat_features),
            'rgb': self.fg_rgb_head(flat_features),
        }
|
|
|
|
class PostNeuralRendererNetwork(nn.Module):
    """Implements the neural renderer in VolumeGAN to render high-resolution
    images.

    Basically, this network executes several convolutional layers in sequence.
    The block at the initial resolution has no modulated convolution: it only
    derives an initial image from the input feature map. Every following
    block doubles the resolution and owns three layers:

    (1) `layer{2 * i - 1}`: convolution with 2x upsampling.
    (2) `layer{2 * i}`: convolution at the block's resolution.
    (3) `output{i}`: convolution producing an image residual.
    """

    def __init__(
            self,
            resolution,
            init_res,
            w_dim,
            image_channels,
            final_tanh,
            demodulate,
            use_wscale,
            wscale_gain,
            lr_mul,
            noise_type,
            fmaps_base,
            fmaps_max,
            filter_kernel,
            conv_clamp,
            eps,
            rgb_init_res_out=False,
    ):
        """Initializes the renderer.

        Args:
            resolution: Resolution of the output image.
            init_res: Resolution of the input feature map.
            w_dim: Dimension of the latent codes.
            image_channels: Number of channels of the output image.
            final_tanh: Whether to apply `tanh` to the final image.
            demodulate .. lr_mul, conv_clamp, eps: Forwarded to the layers.
            noise_type: Forwarded to `layer*`; the output layers always use
                `'none'`.
            fmaps_base: Factor controlling the number of feature maps, see
                `get_nf()`.
            fmaps_max: Maximum number of feature maps of a layer.
            filter_kernel: Kernel used by the upsampling layers and to build
                the `filter` buffer used for upsampling images.
            rgb_init_res_out: Whether the initial image is produced by a
                dedicated convolution (`rgb_init_res`) rather than by taking
                the leading channels of the input feature map.
        """
        super().__init__()

        self.init_res = init_res
        self.init_res_log2 = int(np.log2(init_res))
        self.resolution = resolution
        self.final_res_log2 = int(np.log2(resolution))
        self.w_dim = w_dim
        self.image_channels = image_channels
        self.final_tanh = final_tanh
        self.demodulate = demodulate
        self.use_wscale = use_wscale
        self.wscale_gain = wscale_gain
        self.lr_mul = lr_mul
        self.noise_type = noise_type.lower()
        self.fmaps_base = fmaps_base
        self.fmaps_max = fmaps_max
        self.filter_kernel = filter_kernel
        self.conv_clamp = conv_clamp
        self.eps = eps
        self.rgb_init_res_out = rgb_init_res_out

        # Two latent codes per resolution block.
        self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2

        # Level-of-details, used for progressive training.
        self.register_buffer('lod', torch.zeros(()))

        for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
            res = 2**res_log2
            in_channels = self.get_nf(res // 2)
            out_channels = self.get_nf(res)
            block_idx = res_log2 - self.init_res_log2

            # First layer of the block, with upsampling.
            if res > init_res:
                layer_name = f'layer{2 * block_idx - 1}'
                self.add_module(
                    layer_name,
                    ModulateConvLayer(in_channels=in_channels,
                                      out_channels=out_channels,
                                      resolution=res,
                                      w_dim=w_dim,
                                      kernel_size=1,
                                      add_bias=True,
                                      scale_factor=2,
                                      filter_kernel=filter_kernel,
                                      demodulate=demodulate,
                                      use_wscale=use_wscale,
                                      wscale_gain=wscale_gain,
                                      lr_mul=lr_mul,
                                      noise_type=noise_type,
                                      activation_type='lrelu',
                                      conv_clamp=conv_clamp,
                                      eps=eps))

            # The initial block only owns the optional image head.
            if block_idx == 0:
                if self.rgb_init_res_out:
                    self.rgb_init_res = ConvLayer(
                        in_channels=out_channels,
                        out_channels=image_channels,
                        kernel_size=1,
                        add_bias=True,
                        scale_factor=1,
                        filter_kernel=None,
                        use_wscale=use_wscale,
                        wscale_gain=wscale_gain,
                        lr_mul=lr_mul,
                        activation_type='linear',
                        conv_clamp=conv_clamp,
                    )
                continue

            # Second layer of the block, without upsampling.
            layer_name = f'layer{2 * block_idx}'
            self.add_module(
                layer_name,
                ModulateConvLayer(in_channels=out_channels,
                                  out_channels=out_channels,
                                  resolution=res,
                                  w_dim=w_dim,
                                  kernel_size=1,
                                  add_bias=True,
                                  scale_factor=1,
                                  filter_kernel=None,
                                  demodulate=demodulate,
                                  use_wscale=use_wscale,
                                  wscale_gain=wscale_gain,
                                  lr_mul=lr_mul,
                                  noise_type=noise_type,
                                  activation_type='lrelu',
                                  conv_clamp=conv_clamp,
                                  eps=eps))

            # Output layer of the block.
            layer_name = f'output{block_idx}'
            self.add_module(
                layer_name,
                ModulateConvLayer(in_channels=out_channels,
                                  out_channels=image_channels,
                                  resolution=res,
                                  w_dim=w_dim,
                                  kernel_size=1,
                                  add_bias=True,
                                  scale_factor=1,
                                  filter_kernel=None,
                                  demodulate=False,
                                  use_wscale=use_wscale,
                                  wscale_gain=wscale_gain,
                                  lr_mul=lr_mul,
                                  noise_type='none',
                                  activation_type='linear',
                                  conv_clamp=conv_clamp,
                                  eps=eps))

        # Filter used for upsampling the accumulated image.
        self.register_buffer('filter', upfirdn2d.setup_filter(filter_kernel))

    def get_nf(self, res):
        """Gets number of feature maps according to current resolution."""
        return min(self.fmaps_base // res, self.fmaps_max)

    def set_space_of_latent(self, space_of_latent):
        """Sets the space to which the latent code belong.

        Args:
            space_of_latent: The space to which the latent code belong. Case
                insensitive. Support `W` and `Y`.
        """
        space_of_latent = space_of_latent.upper()
        for module in self.modules():
            if isinstance(module, ModulateConvLayer):
                setattr(module, 'space_of_latent', space_of_latent)

    def forward(self,
                x,
                wp,
                lod=None,
                noise_mode='const',
                fused_modulate=False,
                impl='cuda',
                fp16_res=None,
                nerf_out=False):
        """Renders the image from a 2D feature map.

        Args:
            x: Input feature map at resolution `init_res`.
            wp: Latent codes; layer `i` uses `wp[:, i]` and the output layer
                of block `b` uses `wp[:, 2 * b + 1]`.
            lod: Level-of-details. If `None`, the `lod` buffer is used.
            noise_mode: Forwarded to `layer*`.
            fused_modulate: Forwarded to all modulated layers.
            impl: Implementation mode forwarded to all layers.
            fp16_res: If given, features are cast to `torch.float16` from this
                resolution on.
            nerf_out: If True, the image from the previous block is added to
                the residual as is, i.e. without upsampling.

        Returns:
            A dict with the final image under `'image'`, plus the style codes
            of every executed layer under `'style{i}'` and
            `'output_style{b}'`.
        """
        lod = self.lod.item() if lod is None else lod

        results = {}

        # Cast to `torch.float16` if needed.
        if fp16_res is not None and self.init_res >= fp16_res:
            x = x.to(torch.float16)

        for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
            cur_lod = self.final_res_log2 - res_log2
            block_idx = res_log2 - self.init_res_log2

            if block_idx > 0:
                layer_idxs = [2 * block_idx - 1, 2 * block_idx]
            else:
                layer_idxs = [2 * block_idx]

            # Only run the blocks needed for the requested level-of-details.
            if lod < cur_lod + 1:
                for layer_idx in layer_idxs:
                    if layer_idx == 0:
                        # The initial block has no modulated layers; it only
                        # provides the initial image.
                        if self.rgb_init_res_out:
                            cur_image = self.rgb_init_res(x,
                                                          runtime_gain=1,
                                                          impl=impl)
                        else:
                            # Take as many channels as the output image has,
                            # so that the residuals added by the following
                            # blocks match for any `image_channels`.
                            cur_image = x[:, :self.image_channels]
                        continue

                    layer = getattr(self, f'layer{layer_idx}')
                    x, style = layer(
                        x,
                        wp[:, layer_idx],
                        noise_mode=noise_mode,
                        fused_modulate=fused_modulate,
                        impl=impl,
                    )
                    results[f'style{layer_idx}'] = style

                    if layer_idx % 2 == 0:
                        output_layer = getattr(self, f'output{layer_idx // 2}')
                        y, style = output_layer(
                            x,
                            wp[:, layer_idx + 1],
                            fused_modulate=fused_modulate,
                            impl=impl,
                        )
                        results[f'output_style{layer_idx // 2}'] = style

                        # `layer_idx` is never 0 here (handled above), so the
                        # residual is always added to the previous image.
                        if not nerf_out:
                            cur_image = y.to(
                                torch.float32) + upfirdn2d.upsample2d(
                                    cur_image, self.filter, impl=impl)
                        else:
                            cur_image = y.to(torch.float32) + cur_image

                        # Cast to `torch.float16` if needed.
                        if layer_idx != self.num_layers - 2:
                            res = self.init_res * (2**(layer_idx // 2))
                            if fp16_res is not None and res * 2 >= fp16_res:
                                x = x.to(torch.float16)
                            else:
                                x = x.to(torch.float32)

            # Progressive training: pick, blend or upsample the image
            # depending on where `lod` lies relative to this block.
            if cur_lod - 1 < lod <= cur_lod:
                image = cur_image
            elif cur_lod < lod < cur_lod + 1:
                alpha = np.ceil(lod) - lod
                image = F.interpolate(image, scale_factor=2, mode='nearest')
                image = cur_image * alpha + image * (1 - alpha)
            elif lod >= cur_lod + 1:
                image = F.interpolate(image, scale_factor=2, mode='nearest')

        if self.final_tanh:
            image = torch.tanh(image)
        results['image'] = image

        return results
|
|