| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| """ |
| The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors, depths, and semantics. |
| Based on the implementation in MipNeRF (this one doesn't do any cone tracing, though). |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class MipRayMarcher2(nn.Module):
    """Alpha-composite per-sample network outputs along rays.

    Consecutive samples are averaged (midpoint rule), densities are turned
    into opacities with ``1 - exp(-sigma * delta)``, and colors, depths and
    semantic channels are accumulated with front-to-back transmittance
    weights.
    """

    def __init__(self, activation_factory):
        """Store the density activation factory.

        Args:
            activation_factory: Callable that receives ``rendering_options``
                and returns the activation applied to the midpoint densities.
        """
        super().__init__()
        self.activation_factory = activation_factory

    @staticmethod
    def _drop_semantic_channels(alpha, semantics, level):
        """Apply one batch element's ``level`` rule and return new ``(alpha, semantics)``.

        - level 0: returned unchanged.
        - level 1: opacity is scaled by ``1 - semantics[..., 0:1]``; channel 0
          is zeroed and channels ``1:`` are renormalized to sum to roughly 1.
        - level 2: opacity is scaled by
          ``1 - semantics[..., 0:1] - semantics[..., 3:4]``; channels 0 and 3
          are zeroed and channels ``1:3`` are renormalized. Channels from 4
          onward are passed through untouched.

        All work is done out of place. Writing into slices of tensors that
        autograd has already saved for backward makes ``backward()`` fail with
        a version-counter error.

        Raises:
            NotImplementedError: for any level other than 0, 1 or 2.
        """
        if level == 0:
            return alpha, semantics
        if level == 1:
            removed = semantics[..., 0:1]
            alpha = alpha * (1 - removed)
            kept = semantics[..., 1:]
            kept = (kept + 1e-6) / (torch.sum(kept, dim=-1, keepdim=True) + 1e-6)
            return alpha, torch.cat([torch.zeros_like(removed), kept], dim=-1)
        if level == 2:
            removed_first = semantics[..., 0:1]
            removed_fourth = semantics[..., 3:4]
            alpha = alpha * (1 - removed_first - removed_fourth)
            kept = semantics[..., 1:3]
            kept = (kept + 1e-6) / (torch.sum(kept, dim=-1, keepdim=True) + 1e-6)
            semantics = torch.cat(
                [
                    torch.zeros_like(removed_first),
                    kept,
                    torch.zeros_like(removed_fourth),
                    semantics[..., 4:],
                ],
                dim=-1,
            )
            return alpha, semantics
        raise NotImplementedError("Only 0, 1, 2 levels are supported")

    def run_forward(self, colors, densities, depths, semantics, rendering_options, normals=None):
        """Composite per-sample outputs into per-ray results.

        Samples are sliced along dim 2 and reduced along dim -2, so the inputs
        are presumably rank-4 ``(batch, rays, samples, channels)`` tensors.
        TODO: confirm this with the callers.

        Args:
            colors: Per-sample colors.
            densities: Per-sample raw densities (activated via ``activation_factory``).
            depths: Per-sample depths along each ray.
            semantics: Per-sample semantic channels.
            rendering_options: Dict. ``'levels'`` (required) gives one level
                (0, 1 or 2) per leading-dim element, starting at index 0
                (see ``_drop_semantic_channels``). ``'white_back'`` (optional)
                adds ``1 - total_weight`` to the composited color. The whole
                dict is also passed to ``activation_factory``.
            normals: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(composite_rgb, composite_depth, composite_semantics, weights)``.

        Raises:
            NotImplementedError: if a level other than 0, 1 or 2 is requested.
        """
        dtype = colors.dtype
        deltas = depths[:, :, 1:] - depths[:, :, :-1]
        colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
        densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2
        depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2
        semantics_mid = (semantics[:, :, :-1] + semantics[:, :, 1:]) / 2

        densities_mid = self.activation_factory(rendering_options)(densities_mid).to(dtype)
        density_delta = densities_mid * deltas
        alpha = 1 - torch.exp(-density_delta).to(dtype)

        # Per-element opacity/semantic adjustment. Rows are rebuilt and
        # re-stacked rather than assigned in place so autograd stays valid.
        alpha_rows = list(alpha.unbind(0))
        semantic_rows = list(semantics_mid.unbind(0))
        for bid, level in enumerate(rendering_options['levels']):
            if level == 0:
                continue
            new_alpha, new_semantics = self._drop_semantic_channels(
                alpha_rows[bid], semantic_rows[bid], level
            )
            # Keep alpha's dtype, as slice assignment into alpha used to do.
            alpha_rows[bid] = new_alpha.to(alpha.dtype)
            semantic_rows[bid] = new_semantics
        alpha = torch.stack(alpha_rows)
        semantics_mid = torch.stack(semantic_rows)

        # Exclusive cumulative product = transmittance before each sample.
        # The 1e-10 keeps each factor away from exactly zero when alpha == 1.
        alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1 - alpha + 1e-10], -2)
        weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]
        weights = weights.to(dtype)

        composite_rgb = torch.sum(weights * colors_mid, -2)
        weight_total = weights.sum(2)
        composite_depth = torch.sum(weights * depths_mid, -2)
        composite_semantics = torch.sum(weights * semantics_mid, -2)

        # NaN depths become +inf, then everything is clamped to the sampled range.
        composite_depth = torch.nan_to_num(composite_depth, float('inf')).to(dtype)
        composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths))

        if rendering_options.get('white_back', False):
            composite_rgb = composite_rgb + 1 - weight_total

        return composite_rgb, composite_depth, composite_semantics, weights

    def forward(self, colors, densities, depths, semantics, rendering_options, normals=None):
        """Composite colors, depths and semantics via ``run_forward``.

        Returns:
            Tuple ``(composite_rgb, composite_depth, composite_semantics, weights)``.

        Raises:
            NotImplementedError: if ``normals`` is provided (not supported yet).
        """
        if normals is not None:
            raise NotImplementedError("Normals are not supported in the ray marcher yet.")
        return self.run_forward(colors, densities, depths, semantics, rendering_options)
|
|