# python3.8
"""Contains the functions to represent a point in 3D space.

Typically, a point can be represented by its 3D coordinates, by retrieving from
a feature volume, or by combining triplane features.

Paper (coordinate): https://arxiv.org/pdf/2003.08934.pdf
Paper (feature volume): https://arxiv.org/pdf/2112.10759.pdf
Paper (triplane): https://arxiv.org/pdf/2112.07945.pdf
"""

try:
    # Kept in the module namespace for backward compatibility. The code below
    # only needs a plain reshape, so `einops` is no longer a hard requirement.
    from einops import rearrange  # pylint: disable=unused-import
except ImportError:
    rearrange = None
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['PointRepresenter']

_REPRESENTATION_TYPES = [
    'coordinate', 'volume', 'triplane', 'hybrid', 'mpi', 'oneplane',
    'oneplane_multi'
]


def _as_tensor_or_none(value):
    """Returns `value` unchanged if it is a tensor or `None`.

    Any other value (e.g. a nested Python list) is converted to a `float32`
    tensor, so that it can be registered as a module buffer.
    """
    if value is None or isinstance(value, torch.Tensor):
        return value
    return torch.tensor(value, dtype=torch.float32)


class PointRepresenter(nn.Module):
    """Defines the class to get per-point representation.

    This class implements the `forward()` function to get the representation
    based on the per-point 3D coordinates and the reference representation
    (such as a feature volume or triplane features).
    """

    def __init__(self,
                 representation_type='coordinate',
                 triplane_axes=None,
                 mpi_levels=None,
                 coordinate_scale=None,
                 bound=None,
                 return_eikonal=False,
                 ):
        """Initializes hyper-parameters for getting point representations.

        NOTE: When using triplane representation, the three planes are
        defaulted as follows:

        [
            [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
            [[1, 0, 0], [0, 0, 1], [0, 1, 0]],
            [[0, 0, 1], [0, 1, 0], [1, 0, 0]]
        ]

        where for each plane, the first two rows stand for the plane axes
        while the third row stands for the plane normal.

        Args:
            representation_type: Type of representation used to describe a
                point in the 3D space (case-insensitive). Must be one of
                `_REPRESENTATION_TYPES`. Defaults to `coordinate`.
            triplane_axes: Plane axes used by the `triplane` and `hybrid`
                representations, with shape [n_planes, 3, 3]. A tensor is
                registered as is; any other value is converted to a `float32`
                tensor. Defaults to `None`, meaning the planes listed above.
            mpi_levels: Level-set values used by the `mpi` representation. A
                tensor is registered as is; any other value is converted to a
                `float32` tensor. Defaults to `None`.
            coordinate_scale: Scale factor to normalize coordinates. Ignored
                unless positive. Defaults to `None`.
            bound: Bound used to normalize coordinates, with shape [1, 2, 3].
                Defaults to `None`.
            return_eikonal: If the eikonal loss is to be used, we utilize the
                function `grid_sample_customized()` instead of
                `F.grid_sample()` to avoid errors in computing the second
                derivative. Only affects the `triplane` and `hybrid`
                representations.

        Note that `coordinate_scale` and `bound` both normalize coordinates;
        if both are given, `coordinate_scale` takes precedence in `forward()`.

        Raises:
            ValueError: If `representation_type` is not supported.
        """
        super().__init__()
        self.coordinate_scale = None
        if (coordinate_scale is not None) and (coordinate_scale > 0):
            self.coordinate_scale = coordinate_scale

        bound = _as_tensor_or_none(bound)
        if bound is not None:
            self.register_buffer('bound', bound)
        else:
            self.bound = None
        self.return_eikonal = return_eikonal

        representation_type = representation_type.lower()
        if representation_type not in _REPRESENTATION_TYPES:
            raise ValueError(f'Invalid representation type: '
                             f'`{representation_type}`!\n'
                             f'Types allowed: {_REPRESENTATION_TYPES}.')
        self.representation_type = representation_type

        if self.representation_type in ['coordinate', 'volume']:
            pass
        elif self.representation_type in ['triplane', 'hybrid']:
            if triplane_axes is None:
                triplane_axes = torch.tensor(
                    [[[1, 0, 0], [0, 1, 0], [0, 0, 1]],
                     [[1, 0, 0], [0, 0, 1], [0, 1, 0]],
                     [[0, 0, 1], [0, 1, 0], [1, 0, 0]]],
                    dtype=torch.float32)
            self.register_buffer('triplane_axes',
                                 _as_tensor_or_none(triplane_axes))
        elif self.representation_type in ['oneplane', 'oneplane_multi']:
            self.register_buffer(
                'oneplane_axes',
                torch.tensor([[[1, 0, 0], [0, 1, 0], [0, 0, 1]]],
                             dtype=torch.float32))
        elif self.representation_type == 'mpi':
            # `register_buffer()` only accepts tensors (or `None`), hence
            # Python lists are converted first.
            self.register_buffer('mpi_levels',
                                 _as_tensor_or_none(mpi_levels))
        else:
            raise NotImplementedError(f'Not implemented representation type: '
                                      f'`{self.representation_type}`!\n')

    def forward(self, points, ref_representation=None, align_corners=False):
        """Gets per-point representation based on its coordinates.

        For simplicity, we define the following notations:

        `N` denotes batch size.
        `R` denotes the number of rays, which usually equals `H * W`.
        `K` denotes the number of points on each ray.
        `C` denotes the dimension of per-point representation.

        Args:
            points: Per-point 3D coordinates, with shape [N, R * K, 3]
                ([N, R, K, 3] for the `mpi` representation).
            ref_representation: The reference representation, depending on
                the representation type used. For example, this field will be
                ignored if `self.representation_type` is set as `coordinate`,
                a feature volume is expected if `self.representation_type` is
                set as `volume`, while triplane features are expected if
                `self.representation_type` is set as `triplane`. For `hybrid`,
                a list/tuple `(triplane, feature_volume)` is expected, and for
                `oneplane`/`oneplane_multi` a list/tuple
                `(plane_features, line_features)`. Defaults to `None`.
            align_corners: Forwarded to `F.grid_sample()` for the plane-based
                representations. Defaults to `False`.

        Returns:
            Per-point representation, with shape [N, R * K, C]. For the `mpi`
            representation, the tuple `(intersections, is_valid)` returned by
            `retrieve_from_mpi()` instead.

        Raises:
            ValueError: If the `mpi` representation is used without levels.
        """
        if self.representation_type == 'coordinate':
            return points

        if self.representation_type == 'mpi':
            if self.mpi_levels is None:
                raise ValueError('`mpi_levels` must be provided to use the '
                                 '`mpi` representation!')
            return retrieve_from_mpi(points=points,  # [N, R, K, 3]
                                     isosurfaces=ref_representation,
                                     levels=self.mpi_levels)

        # Normalize point coordinates to the desired range, typically [-1, 1].
        if self.coordinate_scale is not None:
            normalized_points = (2 / self.coordinate_scale) * points
        elif self.bound is not None:
            normalized_points = (points - self.bound[:, :1]) / (
                self.bound[:, 1:] - self.bound[:, :1])  # To range [0, 1].
            normalized_points = 2 * normalized_points - 1  # To range [-1, 1].
        else:
            normalized_points = points

        if self.representation_type == 'volume':
            return retrieve_from_volume(
                coordinates=normalized_points, volume=ref_representation)

        if self.representation_type == 'triplane':
            return retrieve_from_planes(
                plane_axes=self.triplane_axes.to(points.device),
                plane_features=ref_representation,
                coordinates=normalized_points,
                align_corners=align_corners,
                return_eikonal=self.return_eikonal,
            )

        if self.representation_type in ['oneplane', 'oneplane_multi']:
            # `oneplane` concatenates plane and line features along the
            # channel axis, while `oneplane_multi` averages them.
            return retrieve_from_one_plane(
                plane_axes=self.oneplane_axes.to(points.device),
                plane_features=ref_representation,
                coordinates=normalized_points,
                align_corners=align_corners,
                mean=(self.representation_type == 'oneplane_multi'),
            )

        if self.representation_type == 'hybrid':
            assert isinstance(ref_representation, (list, tuple))
            triplane = ref_representation[0]
            feature_volume = ref_representation[1]
            point_features_triplane = retrieve_from_planes(
                plane_axes=self.triplane_axes.to(points.device),
                plane_features=triplane,
                coordinates=normalized_points,
                align_corners=align_corners,
                return_eikonal=self.return_eikonal)
            point_features_volume = retrieve_from_volume(
                coordinates=normalized_points, volume=feature_volume)
            point_features = torch.cat(
                [point_features_volume, point_features_triplane], dim=-1)
            return point_features

        raise NotImplementedError(f'Not implemented representation type: '
                                  f'`{self.representation_type}`!\n')


def grid_sample_3d(volume, coordinates):
    """Performs grid sample in 3D space.

    Given 3D point coordinates, sample values from the volume with trilinear
    interpolation. Coordinates in [-1, 1] are mapped to voxel indices in
    [0, size - 1], and neighbor indices falling outside the volume are clamped
    to the nearest valid voxel.

    Note that this function is similar to function
    `torch.nn.functional.grid_sample()` in the case of 5-D inputs, but is
    built from basic operators only.

    Args:
        volume: The given volume, with shape [N, C, ID, IH, IW].
        coordinates: Input 3D point coordinates, with shape [N, D, H, W, 3],
            where the last axis is ordered as (x, y, z), indexing
            (IW, IH, ID) respectively.

    Returns:
        sampled_vals: Sampled values, with shape [N, C, D, H, W].
    """
    N, C, ID, IH, IW = volume.shape
    _, D, H, W, _ = coordinates.shape
    num_points = D * H * W

    # Map normalized coordinates from [-1, 1] to continuous voxel indices.
    ix = ((coordinates[..., 0] + 1) / 2) * (IW - 1)
    iy = ((coordinates[..., 1] + 1) / 2) * (IH - 1)
    iz = ((coordinates[..., 2] + 1) / 2) * (ID - 1)

    # The neighbor voxel positions are constants with respect to autograd.
    with torch.no_grad():
        x0, y0, z0 = torch.floor(ix), torch.floor(iy), torch.floor(iz)
        x1, y1, z1 = x0 + 1, y0 + 1, z0 + 1
        # Clamped integer indices used for the lookup. They are converted to
        # `int64` BEFORE being flattened, since a floating-point flat index
        # loses precision on large volumes.
        x_idx = (x0.clamp(0, IW - 1).long(), x1.clamp(0, IW - 1).long())
        y_idx = (y0.clamp(0, IH - 1).long(), y1.clamp(0, IH - 1).long())
        z_idx = (z0.clamp(0, ID - 1).long(), z1.clamp(0, ID - 1).long())

    # Per-axis interpolation weights, computed from the UNCLAMPED neighbors.
    # Entry 0 weights the lower neighbor, entry 1 the upper neighbor.
    wx = (x1 - ix, ix - x0)
    wy = (y1 - iy, iy - y0)
    wz = (z1 - iz, iz - z0)

    # `reshape()` (rather than `view()`) also supports non-contiguous inputs.
    flat_volume = volume.reshape(N, C, ID * IH * IW)

    sampled_vals = None
    for dz in (0, 1):
        for dy in (0, 1):
            for dx in (0, 1):
                weight = wx[dx] * wy[dy] * wz[dz]  # [N, D, H, W]
                index = (z_idx[dz] * IH + y_idx[dy]) * IW + x_idx[dx]
                index = index.reshape(N, 1, num_points).expand(-1, C, -1)
                values = torch.gather(flat_volume, 2, index)
                term = (values.reshape(N, C, D, H, W) *
                        weight.reshape(N, 1, D, H, W))
                sampled_vals = (
                    term if sampled_vals is None else sampled_vals + term)

    return sampled_vals


def grid_sample_customized(input, grid):  # pylint: disable=redefined-builtin
    """Customized `grid_sample()` operation.

    Since the original PyTorch `grid_sample()` operator does not support second
    derivative computation during the backward pass, we customize this
    operator with basic operators only.

    The sampling is bilinear. Coordinates in [-1, 1] are mapped to pixel
    indices in [0, size - 1], and neighbor indices falling outside the input
    are clamped to the nearest valid pixel.

    NOTE: If `grid` contains any NaN, the WHOLE grid is replaced by ones and
    `nan` is printed.

    Args:
        input: Input tensor, with shape [N, C, IH, IW].
        grid: Flow-field, with shape [N, H, W, 2], where the last axis is
            ordered as (x, y), indexing (IW, IH) respectively.

    Returns:
        output: Output Tensor, with shape [N, C, H, W].
    """
    N, C, IH, IW = input.shape
    _, H, W, _ = grid.shape

    if torch.any(torch.isnan(grid)):
        grid = torch.ones_like(grid)
        print('nan')

    # Map normalized coordinates from [-1, 1] to continuous pixel indices.
    ix = ((grid[..., 0] + 1) / 2) * (IW - 1)
    iy = ((grid[..., 1] + 1) / 2) * (IH - 1)

    # The neighbor pixel positions are constants with respect to autograd.
    with torch.no_grad():
        x0, y0 = torch.floor(ix), torch.floor(iy)
        x1, y1 = x0 + 1, y0 + 1
        # Convert to `int64` BEFORE flattening to keep the index exact.
        x_idx = (x0.clamp(0, IW - 1).long(), x1.clamp(0, IW - 1).long())
        y_idx = (y0.clamp(0, IH - 1).long(), y1.clamp(0, IH - 1).long())

    # Per-axis interpolation weights, computed from the UNCLAMPED neighbors.
    wx = (x1 - ix, ix - x0)
    wy = (y1 - iy, iy - y0)

    # `reshape()` (rather than `view()`) also supports non-contiguous inputs.
    flat_input = input.reshape(N, C, IH * IW)

    output = None
    for dy in (0, 1):
        for dx in (0, 1):
            weight = wx[dx] * wy[dy]  # [N, H, W]
            index = y_idx[dy] * IW + x_idx[dx]
            index = index.reshape(N, 1, H * W).expand(-1, C, -1)
            values = torch.gather(flat_input, 2, index)
            term = values.reshape(N, C, H, W) * weight.reshape(N, 1, H, W)
            output = term if output is None else output + term

    return output


def retrieve_from_volume(coordinates, volume):
    """Samples point features from feature volume.

    Args:
        coordinates: Coordinate of input 3D points, with shape [N, R * K, 3].
        volume: Feature volume, with shape [N, C, D, H, W].

    Returns:
        output_features: Output sampled point features, with shape
            [N, R * K, C].
    """
    grid_coords = coordinates[:, None, None]  # [N, 1, 1, R * K, 3]
    output_features = grid_sample_3d(volume,
                                     grid_coords)  # [N, C, 1, 1, R * K]
    output_features = output_features[:, :, 0, 0]  # [N, C, R * K]
    output_features = output_features.permute(0, 2, 1)  # [N, R * K, C]

    return output_features


def project_points_onto_planes(points, planes):
    """Projects 3D points onto a batch of 2D planes.

    To project a 3D point `P` onto a 2D plane defined by a normal vector `n`
    and a point `Q` that lies on the plane, one can use the following formula:

        P_proj = P - dot(P - Q, n) * n / dot(n, n)

    where:
        `P_proj` is the projected point on the plane;
        `dot()` is the dot product.

    And `Q` can be chosen as the origin (0, 0, 0) of the coordinate system.
    Meanwhile, if `n` is a normalized vector, then the projection formula is
    simplified as:

        P_proj = P - dot(P, n) * n

    Args:
        points: Point coordinates, with shape [N, M, 3], where `M` is the
            number of points in each batch and equals `R * K`.
        planes: Planes, with shape [n_planes, 3, 3], where `n_planes` is the
            number of planes. Here, a plane is represented by two vector axes
            and one normal vector. For instance, if a plane is represented by:
            `[[0, 0, 1], [0, 1, 0], [1, 0, 0]]`, which means that its axes are
            the third and second axes of the coordinate system, and its normal
            vector is `[1, 0, 0]`.

    Returns:
        projections: Projections, with shape [N * n_planes, R * K, 2].
    """
    plane_normals = planes[:, 2]
    N, M, _ = points.shape  # `M` equals `R * K`.
    n_planes, _ = plane_normals.shape

    # Normalize the normals to unit vectors.
    plane_normals = F.normalize(plane_normals, dim=1)

    # Unsqueeze, expand and reshape tensors.
    points = points.unsqueeze(1).expand(
        -1, n_planes, -1, -1).reshape(N * n_planes, M,
                                      3)  # [N * n_planes, R * K, 3]
    plane_normals = plane_normals.unsqueeze(0).expand(N, -1, -1).reshape(
        N * n_planes, 3)  # [N * n_planes, 3]
    plane_normals = plane_normals.unsqueeze(1).expand(
        -1, M, -1)  # [N * n_planes, R * K, 3]

    # Compute the projections.
    projections = points - torch.sum(
        points * plane_normals, dim=-1).unsqueeze(-1) * plane_normals

    # Express the projections in each plane's own frame, and keep only the
    # two in-plane components (the third one is along the plane normal).
    plane_axes = planes.unsqueeze(0).expand(N, -1, -1, -1).reshape(
        N * n_planes, 3, 3)
    projections = torch.bmm(projections,
                            plane_axes.permute(0, 2, 1))[..., :2]

    return projections


def retrieve_from_planes(plane_axes,
                         plane_features,
                         coordinates,
                         mode='bilinear',
                         align_corners=False,
                         return_eikonal=False,
                         ):
    """Samples point features from triplane.

    Borrowed from
    https://github.com/NVlabs/eg3d/blob/main/eg3d/training/volumetric_rendering/renderer.py

    Args:
        plane_axes: Axes of triplane, with shape [n_planes, 3, 3].
        plane_features: Triplane features, with shape [N, n_planes, C, H, W].
        coordinates: Coordinate of input 3D points, with shape [N, R * K, 3].
        mode: Interpolation mode, forwarded to `F.grid_sample()`.
        align_corners: Forwarded to `F.grid_sample()`.
        return_eikonal: Whether to sample with `grid_sample_customized()`
            (which supports second derivatives) instead of `F.grid_sample()`.
            In that case, `mode` and `align_corners` are ignored.

    Returns:
        output_features: Output sampled point features, averaged over all
            planes, with shape [N, R * K, C].
    """
    N, n_planes, C, H, W = plane_features.shape
    _, M, _ = coordinates.shape  # `M` equals `R * K`.

    # `reshape()` (rather than `view()`) also supports non-contiguous inputs.
    plane_features = plane_features.reshape(N * n_planes, C, H, W)
    projected_coordinates = project_points_onto_planes(
        coordinates, plane_axes).unsqueeze(1)  # [N * n_planes, 1, R * K, 2]

    if return_eikonal:
        output_features = grid_sample_customized(
            plane_features,
            projected_coordinates.float())  # [N * n_planes, C, 1, R * K]
    else:
        output_features = F.grid_sample(
            plane_features,
            projected_coordinates.float(),
            mode=mode,
            padding_mode='zeros',
            align_corners=align_corners)  # [N * n_planes, C, 1, R * K]

    output_features = output_features.permute(
        0, 3, 2, 1)  # [N * n_planes, R * K, 1, C]
    output_features = output_features.reshape(
        N, n_planes, M, C)  # [N, n_planes, R * K, C]
    output_features = output_features.mean(1)  # [N, R * K, C]

    return output_features


def retrieve_from_one_plane(plane_axes,
                            plane_features,
                            coordinates,
                            mode='bilinear',
                            align_corners=False,
                            return_eikonal=False,
                            mean=False,
                            ):
    """Samples point features from one feature plane and one feature line.

    The plane is sampled at the projection of each point onto the plane, while
    the line is sampled at the last (z) coordinate of each point.

    Borrowed from
    https://github.com/NVlabs/eg3d/blob/main/eg3d/training/volumetric_rendering/renderer.py

    Args:
        plane_axes: Axes of the plane, with shape [1, 3, 3].
        plane_features: A list/tuple `(plane, line)`, where `plane` is with
            shape [N, 1, C, H, W] and `line` is with shape [N, C, L].
        coordinates: Coordinate of input 3D points, with shape [N, R * K, 3].
        mode: Interpolation mode, forwarded to `F.grid_sample()`.
        align_corners: Forwarded to `F.grid_sample()`.
        return_eikonal: Whether to sample the plane with
            `grid_sample_customized()` instead of `F.grid_sample()`. The line
            is always sampled with `F.grid_sample()`.
        mean: If `False`, the plane and line features are concatenated along
            the channel axis. Otherwise, they are averaged.

    Returns:
        output_features: Output sampled point features, with shape
            [N, R * K, 2 * C] if `mean` is `False`, otherwise [N, R * K, C].
    """
    assert isinstance(plane_features, (list, tuple))
    N, _, C, H, W = plane_features[0].shape
    _, M, _ = coordinates.shape  # `M` equals `R * K`.

    # Only a single plane is supported, hence the plane axis is dropped.
    one_plane_features = plane_features[0].reshape(N, C, H, W)
    line_features = plane_features[1]  # [N, C, L]
    # Treat the line as an image of width 1 to reuse `F.grid_sample()`.
    line_features_4d = line_features.unsqueeze(-1)  # [N, C, L, 1]

    z_point = coordinates[..., -1:]  # [N, R * K, 1]
    z_point = z_point.unsqueeze(1)  # [N, 1, R * K, 1]
    # The width coordinate is fixed at the center of the single column.
    x_fixed = torch.zeros_like(z_point)
    coordinates_z_cat = torch.cat([x_fixed, z_point],
                                  dim=-1)  # [N, 1, R * K, 2]
    z_features = F.grid_sample(
        line_features_4d,
        coordinates_z_cat.float(),
        mode=mode,
        padding_mode='zeros',
        align_corners=align_corners)  # [N, C, 1, R * K]
    z_features = z_features.permute(0, 2, 3, 1)  # [N, 1, R * K, C]

    projected_coordinates = project_points_onto_planes(
        coordinates, plane_axes).unsqueeze(1)  # [N, 1, R * K, 2]

    if return_eikonal:
        output_features = grid_sample_customized(
            one_plane_features,
            projected_coordinates.float())  # [N, C, 1, R * K]
    else:
        output_features = F.grid_sample(
            one_plane_features,
            projected_coordinates.float(),
            mode=mode,
            padding_mode='zeros',
            align_corners=align_corners)  # [N, C, 1, R * K]

    output_features = output_features.permute(0, 3, 2, 1)  # [N, R * K, 1, C]
    output_features = output_features.reshape(N, 1, M, C)  # [N, 1, R * K, C]

    if not mean:
        # Concatenate along the channel axis.
        output_features = torch.cat([output_features, z_features], dim=-1)
    else:
        # Stack along the "plane" axis, such that `mean(1)` averages them.
        output_features = torch.cat([output_features, z_features], dim=1)

    return output_features.mean(1)


def retrieve_from_mpi(points, isosurfaces, levels):
    """Gets intersections between camera rays and levels.

    For each level, the first interval (between two consecutive points on a
    ray) whose isosurface values bracket the level, i.e.
    `s_high <= level <= s_low`, is selected, and the intersection is obtained
    by linear interpolation within that interval. If no interval brackets the
    level, the last interval is selected and marked as invalid.

    Args:
        points: Coordinate of input 3D points, with shape [N, R, K, 3].
        isosurfaces: Isosurface scalars predicted by MPIPredictor, with shape
            [N, R, K, 1].
        levels: Predefined level set values, as an iterable of scalars (e.g. a
            1D tensor).

    Returns:
        intersections: The intersections between camera rays and the levels,
            with shape [N, R, num_levels, 3], where `num_levels` is the number
            of entries in `levels`.
        is_valid: Whether a level is valid or not, with shape
            [N, R, num_levels, 1] and the same dtype as `intersections`
            (1 for valid, 0 for invalid).
    """
    s_low = isosurfaces[:, :, :-1]  # [N, R, K - 1, 1]
    s_high = isosurfaces[:, :, 1:]  # [N, R, K - 1, 1]

    # Decreasing cost, such that the EARLIEST bracketing interval along the
    # ray gets the highest score.
    K = points.shape[2]
    cost = torch.linspace(K - 1, 0, K - 1,
                          dtype=torch.float32, device=points.device)
    cost = cost.reshape(1, 1, -1, 1)

    # The minimum magnitude of the denominator used for interpolation.
    min_gap = 0.05

    intersections = []
    is_valid = []
    for level in levels:
        # Score is `+cost` for bracketing intervals, `-cost` for the others.
        brackets = (s_high - level <= 0) * (level - s_low <= 0)
        score = (brackets * 2 - 1) * cost
        _, indices = torch.max(score, dim=-2, keepdim=True)  # [N, R, 1, 1]

        # End points and isosurface values of the selected interval.
        point_indices = indices.expand(-1, -1, -1, 3)
        x_l = torch.gather(points, -2, point_indices)[:, :, 0]  # [N, R, 3]
        x_h = torch.gather(points, -2, point_indices + 1)[:, :, 0]
        s_l = torch.gather(s_low, -2, indices)[:, :, 0]  # [N, R, 1]
        s_h = torch.gather(s_high, -2, indices)[:, :, 0]  # [N, R, 1]

        valid = (s_h - level <= 0) * (level - s_l <= 0)
        large_gap = torch.abs(s_h - s_l) > min_gap
        scale = torch.where(large_gap,
                            s_h - s_l,
                            torch.ones_like(s_h) * min_gap)
        # Fall back to the far end point if the interval is invalid or too
        # flat to interpolate reliably.
        intersect = torch.where(
            valid & large_gap,
            ((s_h - level) * x_l + (level - s_l) * x_h) / scale,
            x_h)
        intersections.append(intersect)
        is_valid.append(valid.to(intersect.dtype))

    intersections = torch.stack(intersections, dim=-2)
    is_valid = torch.stack(is_valid, dim=-2)

    return intersections, is_valid