| |
| |
| |
| |
|
|
| |
| import functools |
| from typing import List, Optional, Tuple |
|
|
| import torch |
| from omegaconf import DictConfig |
| from pytorch3d.common.compat import prod |
| from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle |
| from pytorch3d.implicitron.tools.config import ( |
| get_default_args_field, |
| registry, |
| run_auto_creation, |
| ) |
| from pytorch3d.implicitron.tools.utils import evaluating |
|
|
| from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput |
| from .ray_tracing import RayTracing |
| from .rgb_net import RayNormalColoringNetwork |
|
|
|
|
@registry.register
class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
    """
    Renderer for implicit surfaces represented as a signed distance function
    (SDF). A sphere-tracing `RayTracing` module locates the surface along each
    ray, and a `RayNormalColoringNetwork` predicts per-pixel features from the
    surface points and normals. Channel 0 of the implicit function is treated
    as the signed distance; the remaining channels are the rendered features.

    Members:
        render_features_dimensions: Number of feature channels returned in
            `RendererOutput.features` (3 for plain RGB).
        object_bounding_sphere: Radius of the sphere assumed to bound the
            object; forwarded to the ray tracer and used as the sampling range
            for the eikonal-regularization points.
        ray_tracer: Auto-created `RayTracing` sub-module (see
            `create_ray_tracer`).
        ray_normal_coloring_network_args: Constructor arguments for the
            internal `RayNormalColoringNetwork`.
        bg_color: Background feature value; must have either 1 entry or
            `render_features_dimensions` entries.
        soft_mask_alpha: Sharpness of the sigmoid that converts SDF values to
            a soft object mask outside the predicted surface.
    """

    # Number of channels in the rendered feature map (3 == RGB).
    render_features_dimensions: int = 3
    # Radius of the bounding sphere passed to the ray tracer.
    object_bounding_sphere: float = 1.0

    ray_tracer: RayTracing
    ray_normal_coloring_network_args: DictConfig = get_default_args_field(
        RayNormalColoringNetwork
    )
    # Broadcast against the rendered features; length 1 or
    # `render_features_dimensions`.
    bg_color: Tuple[float, ...] = (0.0,)
    soft_mask_alpha: float = 50.0

    def __post_init__(
        self,
    ):
        """
        Validate the configuration, auto-create sub-modules, and build the
        coloring network and the background-color buffer.

        Raises:
            ValueError: If `bg_color` has neither 1 nor
                `render_features_dimensions` entries.
        """
        render_features_dimensions = self.render_features_dimensions
        if len(self.bg_color) not in [1, render_features_dimensions]:
            raise ValueError(
                f"Background color should have {render_features_dimensions} entries."
            )

        # Instantiates `self.ray_tracer` via `create_ray_tracer` (Implicitron
        # configurable-creation machinery).
        run_auto_creation(self)

        # The coloring network must emit the same number of channels as this
        # renderer outputs.
        self.ray_normal_coloring_network_args["feature_vector_size"] = (
            render_features_dimensions
        )
        self._rgb_network = RayNormalColoringNetwork(
            **self.ray_normal_coloring_network_args
        )

        # Non-persistent buffer: follows the module's device but is not saved
        # in checkpoints.
        self.register_buffer("_bg_color", torch.tensor(self.bg_color), persistent=False)

    @classmethod
    def ray_tracer_tweak_args(cls, type, args: DictConfig) -> None:
        # `object_bounding_sphere` is owned by this renderer and injected in
        # `create_ray_tracer`, so it must not be exposed as a configurable
        # argument of the ray tracer itself.
        del args["object_bounding_sphere"]

    def create_ray_tracer(self) -> None:
        # NOTE: `self.ray_tracer_args` is generated by the Implicitron config
        # system from the `ray_tracer: RayTracing` field declared above.
        self.ray_tracer = RayTracing(
            **self.ray_tracer_args,
            object_bounding_sphere=self.object_bounding_sphere,
        )

    def requires_object_mask(self) -> bool:
        # This renderer cannot run without a silhouette mask (see `forward`,
        # which raises if `object_mask` is None).
        return True

    def forward(
        self,
        ray_bundle: ImplicitronRayBundle,
        implicit_functions: List[ImplicitFunctionWrapper],
        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
        object_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> RendererOutput:
        """
        Render a single SDF pass over the given ray bundle.

        Args:
            ray_bundle: An `ImplicitronRayBundle` object containing the
                parametrizations of the sampled rendering rays.
            implicit_functions: Single-element list of
                `ImplicitFunctionWrapper`s which defines the implicit function
                to be used (channel 0 = signed distance, channels 1: =
                rendered features).
            evaluation_mode: one of EvaluationMode.TRAINING or
                EvaluationMode.EVALUATION which determines the settings used
                for rendering. In TRAINING mode, eikonal points and
                differentiable surface points are produced for loss
                computation.
            object_mask: BoolTensor, denoting the silhouette of the object.
                This is a required argument for SignedDistanceFunctionRenderer.
            kwargs: Unused.

        Returns:
            instance of RendererOutput

        Raises:
            ValueError: If `implicit_functions` does not hold exactly one
                element, or if `object_mask` is None.
        """
        if len(implicit_functions) != 1:
            raise ValueError(
                "SignedDistanceFunctionRenderer supports only single pass."
            )

        if object_mask is None:
            raise ValueError("Expected object_mask to be provided in the kwargs")
        object_mask = object_mask.bool()

        implicit_function = implicit_functions[0]
        # Bind `_gradient` (autograd d(sdf)/d(points)) to the current
        # implicit function.
        implicit_function_gradient = functools.partial(_gradient, implicit_function)

        # Flatten all spatial dims into a single "pixels" dimension.
        batch_size, *spatial_size, _ = ray_bundle.lengths.shape
        num_pixels = prod(spatial_size)

        cam_loc = ray_bundle.origins.reshape(batch_size, -1, 3)
        ray_dirs = ray_bundle.directions.reshape(batch_size, -1, 3)
        object_mask = object_mask.reshape(batch_size, -1)

        # The ray tracer runs without gradients (and in eval mode of the
        # implicit function); differentiability is restored afterwards via
        # `_sample_network`.
        with torch.no_grad(), evaluating(implicit_function):
            points, network_object_mask, dists = self.ray_tracer(
                sdf=lambda x: implicit_function(rays_points_world=x)[
                    :, 0
                ],  # channel 0 is the signed distance
                cam_loc=cam_loc,
                object_mask=object_mask,
                ray_directions=ray_dirs,
            )

        # Recompute world points from the traced distances so they are
        # consistent with `dists`.
        depth = dists.reshape(batch_size, num_pixels, 1)
        points = (cam_loc + depth * ray_dirs).reshape(-1, 3)

        # SDF value at every traced point (with gradients this time).
        sdf_output = implicit_function(rays_points_world=points)[:, 0:1]

        # Flatten to per-ray tensors.
        ray_dirs = ray_dirs.reshape(-1, 3)
        object_mask = object_mask.reshape(-1)

        if evaluation_mode == EvaluationMode.TRAINING:
            # Train only on rays that both hit the predicted surface and lie
            # inside the ground-truth silhouette.
            surface_mask = network_object_mask & object_mask
            surface_points = points[surface_mask]
            surface_dists = dists[surface_mask].unsqueeze(-1)
            surface_ray_dirs = ray_dirs[surface_mask]
            surface_cam_loc = cam_loc.reshape(-1, 3)[surface_mask]
            surface_output = sdf_output[surface_mask]
            N = surface_points.shape[0]

            # Sample points for the eikonal loss: uniform samples inside the
            # bounding box plus detached copies of the traced points.
            eik_bounding_box: float = self.object_bounding_sphere
            n_eik_points = batch_size * num_pixels // 2
            eikonal_points = torch.empty(
                n_eik_points,
                3,
                # `_bg_color` is a registered buffer, so its device tracks
                # the module's device.
                device=self._bg_color.device,
            ).uniform_(-eik_bounding_box, eik_bounding_box)
            eikonal_pixel_points = points.clone()
            eikonal_pixel_points = eikonal_pixel_points.detach()
            eikonal_points = torch.cat([eikonal_points, eikonal_pixel_points], 0)

            points_all = torch.cat([surface_points, eikonal_points], dim=0)

            output = implicit_function(rays_points_world=surface_points)
            surface_sdf_values = output[
                :N, 0:1
            ].detach()

            # Single gradient pass over surface + eikonal points; the first N
            # rows correspond to the surface points.
            g = implicit_function_gradient(points_all)
            surface_points_grad = g[:N, 0, :].clone().detach()
            grad_theta = g[N:, 0, :]

            # Re-attach the surface points to the computation graph so they
            # are differentiable w.r.t. the implicit-function parameters.
            differentiable_surface_points = _sample_network(
                surface_output,
                surface_sdf_values,
                surface_points_grad,
                surface_dists,
                surface_cam_loc,
                surface_ray_dirs,
            )

        else:
            # Evaluation: use traced points as-is, no eikonal samples.
            surface_mask = network_object_mask
            differentiable_surface_points = points[surface_mask]
            grad_theta = None

        empty_render = differentiable_surface_points.shape[0] == 0
        features = implicit_function(rays_points_world=differentiable_surface_points)[
            None, :, 1:
        ]
        # Full-resolution output buffers initialized to background values.
        # `requires_grad=True` when nothing was hit keeps the outputs attached
        # to autograd even without any surface writes below.
        normals_full = features.new_zeros(
            batch_size, *spatial_size, 3, requires_grad=empty_render
        )
        render_full = (
            features.new_ones(
                batch_size,
                *spatial_size,
                self.render_features_dimensions,
                requires_grad=empty_render,
            )
            * self._bg_color
        )
        mask_full = features.new_ones(
            batch_size, *spatial_size, 1, requires_grad=empty_render
        )
        if not empty_render:
            normals = implicit_function_gradient(differentiable_surface_points)[
                None, :, 0, :
            ]
            # Scatter the rendered values into the flattened full-size
            # buffers at the surface pixels.
            normals_full.view(-1, 3)[surface_mask] = normals
            render_full.view(-1, self.render_features_dimensions)[surface_mask] = (
                self._rgb_network(
                    features,
                    differentiable_surface_points[None],
                    normals,
                    ray_bundle,
                    surface_mask[None, :, None],
                    pooling_fn=None,
                )
            )
            # Soft silhouette outside the surface: sigmoid of the negated,
            # scaled SDF value.
            mask_full.view(-1, 1)[~surface_mask] = torch.sigmoid(
                -self.soft_mask_alpha * sdf_output[~surface_mask]
            )

        # World-space point map: ray origins everywhere, overwritten with the
        # (differentiable) surface points at surface pixels.
        points_full = ray_bundle.origins.detach().clone()
        points_full.view(-1, 3)[surface_mask] = differentiable_surface_points

        return RendererOutput(
            features=render_full,
            normals=normals_full,
            depths=depth.reshape(batch_size, *spatial_size, 1),
            masks=mask_full,
            points=points_full,
            aux={"grad_theta": grad_theta},
        )
|
|
|
|
| def _sample_network( |
| surface_output, |
| surface_sdf_values, |
| surface_points_grad, |
| surface_dists, |
| surface_cam_loc, |
| surface_ray_dirs, |
| eps: float = 1e-4, |
| ): |
| |
| surface_ray_dirs_0 = surface_ray_dirs.detach() |
| surface_points_dot = torch.bmm( |
| surface_points_grad.view(-1, 1, 3), surface_ray_dirs_0.view(-1, 3, 1) |
| ).squeeze(-1) |
| dot_sign = (surface_points_dot >= 0).to(surface_points_dot) * 2 - 1 |
| surface_dists_theta = surface_dists - (surface_output - surface_sdf_values) / ( |
| surface_points_dot.abs().clip(eps) * dot_sign |
| ) |
|
|
| |
| surface_points_theta_c_v = surface_cam_loc + surface_dists_theta * surface_ray_dirs |
|
|
| return surface_points_theta_c_v |
|
|
|
|
| @torch.enable_grad() |
| def _gradient(module, rays_points_world): |
| rays_points_world.requires_grad_(True) |
| y = module.forward(rays_points_world=rays_points_world)[:, :1] |
| d_output = torch.ones_like(y, requires_grad=False, device=y.device) |
| gradients = torch.autograd.grad( |
| outputs=y, |
| inputs=rays_points_world, |
| grad_outputs=d_output, |
| create_graph=True, |
| retain_graph=True, |
| only_inputs=True, |
| )[0] |
| return gradients.unsqueeze(1) |
|
|