|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Back projection module.""" |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
from ..utils.frustum import ( |
|
|
generate_frustum, |
|
|
generate_frustum_volume, |
|
|
compute_camera2frustum_transform |
|
|
) |
|
|
|
|
|
|
|
|
class BackProjection(nn.Module): |
|
|
"""Back projection module.""" |
|
|
|
|
|
def __init__(self, cfg): |
|
|
"""Initialize the back projection module.""" |
|
|
super().__init__() |
|
|
self.image_size = cfg.dataset.depth_size |
|
|
self.depth_min = cfg.dataset.depth_min |
|
|
self.depth_max = cfg.dataset.depth_max |
|
|
self.voxel_size = cfg.model.projection.voxel_size |
|
|
self.frustum_dimensions = torch.tensor([cfg.model.frustum3d.frustum_dims] * 3) |
|
|
|
|
|
def forward( |
|
|
self, shp, intrinsics, frustum_masks=None, room_masks=None |
|
|
): |
|
|
"""Forward pass.""" |
|
|
device = intrinsics.device |
|
|
if frustum_masks is None: |
|
|
frustum_masks = torch.ones( |
|
|
[len(intrinsics), *self.frustum_dimensions], |
|
|
dtype=torch.bool, device=device |
|
|
) |
|
|
len_shp = len(frustum_masks.shape) |
|
|
if len_shp == 3: |
|
|
frustum_masks = frustum_masks[None] |
|
|
intrinsics = intrinsics[None] |
|
|
|
|
|
kepts, mappings = [], [] |
|
|
for bi, (intrinsic, frustum_mask) in enumerate(zip(intrinsics, frustum_masks)): |
|
|
camera2frustum = compute_camera2frustum_transform( |
|
|
intrinsic.cpu(), self.image_size, self.depth_min, |
|
|
self.depth_max, self.voxel_size |
|
|
).to(device) |
|
|
intrinsic_inverse = torch.inverse(intrinsic) |
|
|
coordinates = torch.nonzero(frustum_mask) |
|
|
grid_coordinates = coordinates.clone() |
|
|
grid_coordinates[:, :2] = 256 - grid_coordinates[:, :2] |
|
|
|
|
|
padding_offsets = self.compute_frustum_padding(intrinsic_inverse) |
|
|
grid_coordinates = grid_coordinates - padding_offsets - torch.tensor([1., 1., 1.], device=device) |
|
|
grid_coordinates = torch.cat([ |
|
|
grid_coordinates, torch.ones(len(grid_coordinates), 1, device=device)], 1 |
|
|
) |
|
|
pointcloud = torch.mm(torch.inverse(camera2frustum), grid_coordinates.t()) |
|
|
depth_pixels = torch.mm(intrinsic, pointcloud) |
|
|
|
|
|
depth = depth_pixels[2] |
|
|
coord = depth_pixels[0:2] / depth |
|
|
coord = torch.cat([coord, coordinates[:, 2][None]], 0).permute(1, 0) |
|
|
|
|
|
kept = (depth <= self.depth_max) * \ |
|
|
(depth >= self.depth_min) * \ |
|
|
(coord[:, 0] < shp[1]) * (coord[:, 1] < shp[0]) |
|
|
coordinates = coordinates[kept] |
|
|
depth = depth[kept, None] |
|
|
|
|
|
mapping = torch.zeros(256, 256, 256, 5, device=depth.device) - 1. |
|
|
mapping[coordinates[:, 0], coordinates[:, 1], coordinates[:, 2]] = \ |
|
|
torch.cat([torch.ones_like(depth) * bi, coord[kept], depth], -1) |
|
|
|
|
|
kept = (mapping >= 0).all(-1) |
|
|
|
|
|
if room_masks is not None: |
|
|
mapping_kept = mapping[kept].long() |
|
|
kept[kept.clone()] = room_masks[bi, 0, mapping_kept[:, 2], mapping_kept[:, 1]] |
|
|
|
|
|
kepts.append(kept) |
|
|
mappings.append(mapping) |
|
|
|
|
|
if len_shp == 3: |
|
|
kepts = kepts[0] |
|
|
mappings = mappings[0][..., 1:] |
|
|
else: |
|
|
kepts = torch.stack(kepts, 0) |
|
|
mappings = torch.stack(mappings, 0) |
|
|
|
|
|
return kepts, mappings |
|
|
|
|
|
def compute_frustum_padding(self, intrinsic_inverse: torch.Tensor) -> torch.Tensor: |
|
|
"""Compute frustum padding.""" |
|
|
frustum = generate_frustum( |
|
|
self.image_size, intrinsic_inverse.cpu(), self.depth_min, self.depth_max |
|
|
) |
|
|
dimensions, _ = generate_frustum_volume(frustum, self.voxel_size) |
|
|
difference = ( |
|
|
self.frustum_dimensions - torch.tensor(dimensions) |
|
|
).float().to(intrinsic_inverse.device) |
|
|
padding_offsets = difference / 2 |
|
|
return padding_offsets |
|
|
|