# Source: DVD/utils/d2n/plane_svd.py
# Author: haodongli
# Commit: init-1 (4b35c4e)
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
class Depth2normal(nn.Module):
    """Estimate per-pixel surface normals from a 3D point map by local plane fitting.

    For every pixel the (dilated) ``k x k`` neighbourhood of points is gathered
    and stacked into a matrix ``A`` whose rows are ``[x, y, z, 1]``. Rows that
    belong to invalid neighbours are zeroed. The plane parameters are taken as
    the eigenvector of ``A^T A`` with the smallest eigenvalue, and its first
    three components, normalised to unit length, form the normal.

    Args:
        d_min: lower bound (exclusive) of the valid depth range.
        d_max: upper bound (exclusive) of the valid depth range.
        k: neighbourhood size; a ``k x k`` window around each pixel is used.
        d: spacing between neighbours (dilation of the window).
        min_nghbr: minimum number of valid neighbours required for a pixel's
            normal to be flagged as valid.
        gamma: optional relative depth threshold. A neighbour is ignored when
            ``|z_nghbr - z_center| / z_center`` is not smaller than ``gamma``
            (e.g. ``gamma=0.05`` drops neighbours that differ by 5% or more).
            ``None`` disables the test.
        gamma_exception: when ``True``, the ``gamma`` test is ignored for pixels
            where fewer than ``min_nghbr`` neighbours pass it. This helps on
            surfaces seen at grazing angles, where the depth gradient is large
            and most neighbours would otherwise be rejected.
    """

    def __init__(self,
                 d_min: float = 0.0,
                 d_max: float = 10.0,
                 k: int = 3,
                 d: int = 1,
                 min_nghbr: int = 4,
                 gamma: Optional[float] = None,
                 gamma_exception: bool = False):
        super().__init__()

        # range of valid depth values; depth outside (d_min, d_max) is invalid
        self.d_min = d_min
        self.d_max = d_max

        # neighbourhood size and spacing between neighbours (dilation)
        self.k = k
        self.d = d

        # relative depth-difference threshold (None disables the test)
        self.gamma = gamma

        # minimum number of valid neighbours for a normal to be considered valid
        self.min_nghbr = min_nghbr

        # whether to drop the gamma test where too few neighbours pass it
        self.gamma_exception = gamma_exception

        # padding that keeps the number of windows equal to h * w:
        # half of the dilated window extent k + (k - 1) * (d - 1)
        self.pad = (k + (k - 1) * (d - 1)) // 2

        # index of the center pixel inside the flattened k*k window
        self.center_idx = (k * k - 1) // 2

        # extracts the (zero-padded, dilated) k x k window around every pixel
        self.unfold = torch.nn.Unfold(kernel_size=(k, k), padding=self.pad, dilation=d)

    def forward(self, points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute normals and their validity mask.

        Args:
            points: 3D points in camera coordinates, shape ``(B, 3, H, W)``.
                Channel 2 is used as the depth.

        Returns:
            A tuple ``(normal, valid_mask)``:
            ``normal`` has shape ``(B, 3, H, W)`` and holds unit vectors signed
            so that their dot product with the pixel's own point is
            non-negative (it is zeroed where that dot product is exactly 0);
            ``valid_mask`` is a boolean tensor of shape ``(B, 1, H, W)``.
        """
        b, _, h, w = points.shape
        kk = self.k * self.k
        c = self.center_idx

        # neighbourhood points: (b, 3*kk, h*w) -> (b, 3, kk, h, w) -> (b, h, w, kk, 3)
        patches = self.unfold(points)
        nghbr = patches.view(b, 3, kk, h, w).permute(0, 3, 4, 2, 1)

        # per-neighbour validity from the depth range, unfolded the same way.
        # Zero padding added by Unfold yields 0 here, so out-of-image
        # neighbours are invalid.
        depth = points[:, 2:, :, :]
        in_range = torch.logical_and(depth > self.d_min, depth < self.d_max).float()
        valid = self.unfold(in_range).view(b, 1, kk, h, w).permute(0, 3, 4, 2, 1)  # (b, h, w, kk, 1)

        if self.gamma is not None:
            # relative depth difference between every neighbour and the center pixel
            center_depth = nghbr[:, :, :, c:c + 1, 2:]
            rel_diff = torch.abs(nghbr[:, :, :, :, 2:] - center_depth) / center_depth
            close = (rel_diff < self.gamma).float()  # (b, h, w, kk, 1)

            if self.gamma_exception:
                # where fewer than min_nghbr neighbours pass, let all of them pass
                too_few = (torch.sum(close, dim=3, keepdim=True) < self.min_nghbr).float()
                close = ((close + too_few) > 0.5).float()

            valid = valid * close

        # homogeneous design matrix A: (b, h, w, kk, 4), invalid rows set to zero.
        # torch.where (rather than a multiplication) keeps non-finite values of
        # invalid neighbours out of the matrix; the condition broadcasts over
        # the last dimension.
        ones = torch.ones_like(nghbr[:, :, :, :, 0:1])
        mat_a = torch.cat([nghbr, ones], dim=-1)
        mat_a = torch.where(valid > 0.5, mat_a, torch.zeros_like(mat_a))
        mat_a = mat_a.reshape(-1, kk, 4)  # (b*h*w, kk, 4)

        # A^T A: (b*h*w, 4, 4), symmetric positive semi-definite
        at_a = torch.bmm(mat_a.transpose(1, 2), mat_a)

        # Symmetric eigensolver: real eigenvalues in ascending order, with the
        # matching eigenvectors stored as columns. Column 0 therefore belongs
        # to the smallest eigenvalue.
        eig_val, eig_vec = torch.linalg.eigh(at_a)
        plane = eig_vec[:, :, 0]  # (b*h*w, 4)
        finite_eig = torch.isfinite(eig_val).all(dim=1)  # (b*h*w,)

        # normal (b, 3, h, w): first three plane parameters, unit length
        normal = plane.reshape(b, h, w, 4).permute(0, 3, 1, 2).contiguous()
        normal = F.normalize(normal[:, :3, :, :], p=2.0, dim=1, eps=1e-12)

        # resolve the sign ambiguity of the eigenvector using the pixel's own point
        flip = torch.sign(torch.sum(normal * points, dim=1, keepdim=True))
        normal = normal * flip

        # validity, each of shape (b, 1, h, w)
        center_valid = valid[:, :, :, c, 0].unsqueeze(1) > 0.5                        # center pixel itself is valid
        enough_nghbr = torch.sum(valid[..., 0], dim=3).unsqueeze(1) >= self.min_nghbr  # sufficient valid neighbours
        eig_ok = finite_eig.view(b, h, w).unsqueeze(1)                                 # decomposition is finite
        unit_norm = torch.norm(normal, p=2, dim=1, keepdim=True) > 0.5                 # normal was not zeroed / NaN

        valid_mask = center_valid & enough_nghbr & eig_ok & unit_norm
        return normal, valid_mask