| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import itertools |
| | import torch |
| | import torch.nn as nn |
| |
|
| | from .utils.renderer import generate_planes, project_onto_planes, sample_from_planes |
| |
|
| |
|
| | class OSGDecoder(nn.Module): |
| | """ |
| | Triplane decoder that gives RGB and sigma values from sampled features. |
| | Using ReLU here instead of Softplus in the original implementation. |
| | |
| | Reference: |
| | EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 |
| | """ |
| | def __init__(self, n_features: int, |
| | hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU): |
| | super().__init__() |
| |
|
| | self.net_sdf = nn.Sequential( |
| | nn.Linear(3 * n_features, hidden_dim), |
| | activation(), |
| | *itertools.chain(*[[ |
| | nn.Linear(hidden_dim, hidden_dim), |
| | activation(), |
| | ] for _ in range(num_layers - 2)]), |
| | nn.Linear(hidden_dim, 1), |
| | ) |
| | self.net_rgb = nn.Sequential( |
| | nn.Linear(3 * n_features, hidden_dim), |
| | activation(), |
| | *itertools.chain(*[[ |
| | nn.Linear(hidden_dim, hidden_dim), |
| | activation(), |
| | ] for _ in range(num_layers - 2)]), |
| | nn.Linear(hidden_dim, 3), |
| | ) |
| | self.net_deformation = nn.Sequential( |
| | nn.Linear(3 * n_features, hidden_dim), |
| | activation(), |
| | *itertools.chain(*[[ |
| | nn.Linear(hidden_dim, hidden_dim), |
| | activation(), |
| | ] for _ in range(num_layers - 2)]), |
| | nn.Linear(hidden_dim, 3), |
| | ) |
| | self.net_weight = nn.Sequential( |
| | nn.Linear(8 * 3 * n_features, hidden_dim), |
| | activation(), |
| | *itertools.chain(*[[ |
| | nn.Linear(hidden_dim, hidden_dim), |
| | activation(), |
| | ] for _ in range(num_layers - 2)]), |
| | nn.Linear(hidden_dim, 21), |
| | ) |
| |
|
| | |
| | for m in self.modules(): |
| | if isinstance(m, nn.Linear): |
| | nn.init.zeros_(m.bias) |
| |
|
| | def get_geometry_prediction(self, sampled_features, flexicubes_indices): |
| | _N, n_planes, _M, _C = sampled_features.shape |
| | sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) |
| |
|
| | sdf = self.net_sdf(sampled_features) |
| | deformation = self.net_deformation(sampled_features) |
| |
|
| | grid_features = torch.index_select(input=sampled_features, index=flexicubes_indices.reshape(-1), dim=1) |
| | grid_features = grid_features.reshape( |
| | sampled_features.shape[0], flexicubes_indices.shape[0], flexicubes_indices.shape[1] * sampled_features.shape[-1]) |
| | weight = self.net_weight(grid_features) * 0.1 |
| |
|
| | return sdf, deformation, weight |
| | |
| | def get_texture_prediction(self, sampled_features): |
| | _N, n_planes, _M, _C = sampled_features.shape |
| | sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) |
| |
|
| | rgb = self.net_rgb(sampled_features) |
| | rgb = torch.sigmoid(rgb)*(1 + 2*0.001) - 0.001 |
| |
|
| | return rgb |
| |
|
| |
|
| | class TriplaneSynthesizer(nn.Module): |
| | """ |
| | Synthesizer that renders a triplane volume with planes and a camera. |
| | |
| | Reference: |
| | EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19 |
| | """ |
| |
|
| | DEFAULT_RENDERING_KWARGS = { |
| | 'ray_start': 'auto', |
| | 'ray_end': 'auto', |
| | 'box_warp': 2., |
| | 'white_back': True, |
| | 'disparity_space_sampling': False, |
| | 'clamp_mode': 'softplus', |
| | 'sampler_bbox_min': -1., |
| | 'sampler_bbox_max': 1., |
| | } |
| |
|
| | def __init__(self, triplane_dim: int, samples_per_ray: int): |
| | super().__init__() |
| |
|
| | |
| | self.triplane_dim = triplane_dim |
| | self.rendering_kwargs = { |
| | **self.DEFAULT_RENDERING_KWARGS, |
| | 'depth_resolution': samples_per_ray // 2, |
| | 'depth_resolution_importance': samples_per_ray // 2, |
| | } |
| |
|
| | |
| | self.plane_axes = generate_planes() |
| | self.decoder = OSGDecoder(n_features=triplane_dim) |
| |
|
| | def get_geometry_prediction(self, planes, sample_coordinates, flexicubes_indices): |
| | plane_axes = self.plane_axes.to(planes.device) |
| | sampled_features = sample_from_planes( |
| | plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp']) |
| |
|
| | sdf, deformation, weight = self.decoder.get_geometry_prediction(sampled_features, flexicubes_indices) |
| | return sdf, deformation, weight |
| | |
| | def get_texture_prediction(self, planes, sample_coordinates): |
| | plane_axes = self.plane_axes.to(planes.device) |
| | sampled_features = sample_from_planes( |
| | plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp']) |
| |
|
| | rgb = self.decoder.get_texture_prediction(sampled_features) |
| | return rgb |
| |
|