"""Contains the implementation of the generator described in BEV3D."""

import torch
import torch.nn as nn

from models.utils.official_stylegan2_model_helper import Generator as StyleGAN2Backbone
from models.utils.official_stylegan2_model_helper import FullyConnectedLayer
from models.utils.eg3d_superres import SuperresolutionHybrid2X
from models.utils.eg3d_superres import SuperresolutionHybrid4X
from models.utils.eg3d_superres import SuperresolutionHybrid4X_conststyle
from models.utils.eg3d_superres import SuperresolutionHybrid8XDC
from models.rendering.renderer import Renderer
from models.rendering.feature_extractor import FeatureExtractor
from models.utils.spade import SPADEGenerator

class BEV3DGenerator(nn.Module):
    """Generator that renders images from a BEV semantic layout and a latent code."""

    def __init__(
        self,
        z_dim,
        semantic_nc,
        ngf,
        bev_grid_size,
        aspect_ratio,
        num_upsampling_layers,
        not_use_vae,
        norm_G,
        img_resolution,
        interpolate_sr,
        segmask=False,
        dim_seq='16,8,4,2,1',
        xyz_pe=False,
        hidden_dim=64,
        additional_layer_num=0,
        sr_num_fp16_res=0,
        rendering_kwargs={},
        sr_kwargs={},
    ):
        super().__init__()

        self.z_dim = z_dim
        self.interpolate_sr = interpolate_sr
        self.segmask = segmask

        # Set up the volume renderer.
        self.renderer = Renderer()

        # Set up the feature extractor that samples features from the BEV plane.
        self.feature_extractor = FeatureExtractor(ref_mode='bev_plane_clevr', xyz_pe=xyz_pe)

        # Set up the SPADE backbone that maps the semantic layout (and latent code)
        # to BEV feature planes.
        self.backbone = SPADEGenerator(
            z_dim=z_dim, semantic_nc=semantic_nc, ngf=ngf, dim_seq=dim_seq,
            bev_grid_size=bev_grid_size, aspect_ratio=aspect_ratio,
            num_upsampling_layers=num_upsampling_layers,
            not_use_vae=not_use_vae, norm_G=norm_G)
        print('backbone SPADEGenerator set up!')

        # No post module is used on top of the feature extractor.
        self.post_module = None

        # Set up the post neural renderer (super-resolution module).
        self.post_neural_renderer = None
        sr_kwargs_total = dict(
            channels=32,
            img_resolution=img_resolution,
            sr_num_fp16_res=sr_num_fp16_res,
            sr_antialias=rendering_kwargs['sr_antialias'])
        sr_kwargs_total.update(**sr_kwargs)
        if img_resolution == 128:
            self.post_neural_renderer = SuperresolutionHybrid2X(**sr_kwargs_total)
        elif img_resolution == 256:
            self.post_neural_renderer = SuperresolutionHybrid4X_conststyle(**sr_kwargs_total)
        elif img_resolution == 512:
            self.post_neural_renderer = SuperresolutionHybrid8XDC(**sr_kwargs_total)
        else:
            raise TypeError(f'Unsupported image resolution: {img_resolution}!')

        # Set up the fully-connected head that decodes sampled features into
        # RGB features and density.
        self.fc_head = OSGDecoder(
            128 if xyz_pe else 64,
            {
                'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
                'decoder_output_dim': 32
            },
            hidden_dim=hidden_dim,
            additional_layer_num=additional_layer_num)

        # Rendering-related settings.
        self.neural_rendering_resolution = rendering_kwargs.get('resolution', 64)
        self.rendering_kwargs = rendering_kwargs

    def synthesis(self,
                  z,
                  c,
                  seg,
                  neural_rendering_resolution=None,
                  update_emas=False,
                  **synthesis_kwargs):
        cam2world_matrix = c[:, :16].view(-1, 4, 4)
        if self.rendering_kwargs.get('random_pose', False):
            cam2world_matrix = None

        if neural_rendering_resolution is None:
            neural_rendering_resolution = self.neural_rendering_resolution
        else:
            self.neural_rendering_resolution = neural_rendering_resolution

        # Map the semantic layout to BEV feature planes, optionally masking them
        # with the first segmentation channel.
        xy_planes = self.backbone(z=z, input=seg)
        if self.segmask:
            xy_planes = xy_planes * seg[:, 0, ...][:, None, ...]

        # No mapping network is used: the latent code serves directly as wp.
        wp = z

        # Volume-render the BEV representation into feature and depth samples.
        rendering_result = self.renderer(
            wp=wp,
            feature_extractor=self.feature_extractor,
            rendering_options=self.rendering_kwargs,
            cam2world_matrix=cam2world_matrix,
            position_encoder=None,
            ref_representation=xy_planes,
            post_module=self.post_module,
            fc_head=self.fc_head)

        feature_samples = rendering_result['composite_rgb']
        depth_samples = rendering_result['composite_depth']

        # Reshape the flattened ray samples back into (N, C, H, W) images.
        N = wp.shape[0]
        H = W = self.neural_rendering_resolution
        feature_image = feature_samples.permute(0, 2, 1).reshape(
            N, feature_samples.shape[-1], H, W).contiguous()
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)

        # Upsample the raw rendering, either by plain interpolation or with the
        # super-resolution module.
        rgb_image = feature_image[:, :3]
        if self.interpolate_sr:
            sr_image = torch.nn.functional.interpolate(
                rgb_image, size=(256, 256), mode='bilinear', align_corners=False)
        else:
            sr_image = self.post_neural_renderer(
                rgb_image,
                feature_image,
                noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
                **{
                    k: synthesis_kwargs[k]
                    for k in synthesis_kwargs.keys() if k != 'noise_mode'
                })

        return {
            'image': sr_image,
            'image_raw': rgb_image,
            'image_depth': depth_image
        }

    def sample(self,
               coordinates,
               directions,
               z,
               c,
               seg,
               truncation_psi=1,
               truncation_cutoff=None,
               update_emas=False,
               **synthesis_kwargs):
        # Query RGB features and densities at arbitrary 3D coordinates.
        cam2world_matrix = c[:, :16].view(-1, 4, 4)
        xy_planes = self.backbone(z=z, input=seg)
        wp = z
        result = self.renderer.get_sigma_rgb(
            wp=wp,
            points=coordinates,
            feature_extractor=self.feature_extractor,
            fc_head=self.fc_head,
            rendering_options=self.rendering_kwargs,
            ref_representation=xy_planes,
            post_module=self.post_module,
            ray_dirs=directions,
            cam_matrix=cam2world_matrix)

        return result

    def sample_mixed(self,
                     coordinates,
                     directions,
                     z, c, seg,
                     truncation_psi=1,
                     truncation_cutoff=None,
                     update_emas=False,
                     **synthesis_kwargs):
        # Same computation as `sample`, kept as a separate entry point.
        cam2world_matrix = c[:, :16].view(-1, 4, 4)
        xy_planes = self.backbone(z=z, input=seg)
        wp = z
        result = self.renderer.get_sigma_rgb(
            wp=wp,
            points=coordinates,
            feature_extractor=self.feature_extractor,
            fc_head=self.fc_head,
            rendering_options=self.rendering_kwargs,
            ref_representation=xy_planes,
            post_module=self.post_module,
            ray_dirs=directions,
            cam_matrix=cam2world_matrix)

        return result

    def forward(self,
                z,
                c,
                seg,
                c_swapped=None,
                style_mixing_prob=0,
                truncation_psi=1,
                truncation_cutoff=None,
                neural_rendering_resolution=None,
                update_emas=False,
                sample_mixed=False,
                coordinates=None,
                **synthesis_kwargs):

        # Optionally swap in a different conditioning pose; note that c_wp is not
        # consumed below, since no mapping network is used.
        c_wp = c.clone()
        if c_swapped is not None:
            c_wp = c_swapped.clone()

        if not sample_mixed:
            # Render the full image.
            gen_output = self.synthesis(
                z,
                c,
                seg,
                update_emas=update_emas,
                neural_rendering_resolution=neural_rendering_resolution,
                **synthesis_kwargs)

            return {
                'wp': z,
                'gen_output': gen_output,
            }

        else:
            # Only sample densities at the given coordinates (e.g. for density
            # regularization during training).
            assert coordinates is not None
            sample_sigma = self.sample_mixed(coordinates,
                                             torch.randn_like(coordinates),
                                             z, c, seg,
                                             update_emas=False)['sigma']

            return {
                'wp': z,
                'sample_sigma': sample_sigma
            }
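
# Typical usage (illustrative sketch only, not part of the training pipeline):
# the conditioning vector `c` packs a flattened 4x4 cam2world matrix in
# c[:, :16], and `rendering_kwargs` must provide at least 'sr_antialias' and
# 'superresolution_noise_mode' (plus whatever options Renderer consumes):
#
#     G = BEV3DGenerator(..., img_resolution=256, interpolate_sr=False,
#                        rendering_kwargs=rendering_opts)
#     out = G(z, c, seg)
#     image = out['gen_output']['image']          # super-resolved RGB
#     image_raw = out['gen_output']['image_raw']  # low-resolution RGB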


class OSGDecoder(nn.Module):
    """Defines the fully-connected layer head in EG3D."""

    def __init__(self, n_features, options, hidden_dim=64, additional_layer_num=0):
        super().__init__()
        self.hidden_dim = hidden_dim

        lst = []
        lst.append(FullyConnectedLayer(n_features, self.hidden_dim, lr_multiplier=options['decoder_lr_mul']))
        lst.append(nn.Softplus())
        for _ in range(additional_layer_num):
            lst.append(FullyConnectedLayer(self.hidden_dim, self.hidden_dim, lr_multiplier=options['decoder_lr_mul']))
            lst.append(nn.Softplus())
        lst.append(FullyConnectedLayer(self.hidden_dim, 1 + options['decoder_output_dim'], lr_multiplier=options['decoder_lr_mul']))
        self.net = nn.Sequential(*lst)

    def forward(self, point_features, wp=None, dirs=None):
        # point_features: (N, R, K, C) = (batch, rays, samples per ray, channels).
        N, R, K, C = point_features.shape
        x = point_features.reshape(-1, point_features.shape[-1])
        x = self.net(x)
        x = x.view(N, -1, x.shape[-1])

        # A scaled sigmoid keeps the RGB features in (-0.001, 1.001), as in EG3D;
        # the first output channel is the raw density.
        rgb = torch.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001
        sigma = x[..., 0:1]

        return {'rgb': rgb, 'sigma': sigma}
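

# Minimal smoke test for OSGDecoder (illustrative only). The feature width of 64
# matches the xyz_pe=False branch of BEV3DGenerator above; the batch, ray, and
# sample counts below are example values, not values from the original config.
if __name__ == '__main__':
    decoder = OSGDecoder(
        64,
        {'decoder_lr_mul': 1, 'decoder_output_dim': 32},
        hidden_dim=64,
        additional_layer_num=0)
    # (N, R, K, C): batch of 2, 16 rays, 8 samples per ray, 64 feature channels.
    feats = torch.randn(2, 16, 8, 64)
    out = decoder(feats)
    print(out['rgb'].shape)    # torch.Size([2, 128, 32])
    print(out['sigma'].shape)  # torch.Size([2, 128, 1])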