Spaces:
Sleeping
Sleeping
| # Copyright Niantic 2019. Patent Pending. All rights reserved. | |
| # | |
| # This software is licensed under the terms of the Monodepth2 licence | |
| # which allows for non-commercial use only, the full terms of which are made | |
| # available in the LICENSE file. | |
| from __future__ import absolute_import, division, print_function | |
| import numpy as np | |
| from scipy.spatial.transform import Rotation as R | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| # from torchmetrics.image.fid import FrechetInceptionDistance | |
| # def silog(real1, fake1): | |
| # # filter out invalid pixels | |
| # real = real1.clone() | |
| # fake = fake1.clone() | |
| # N = (real>0).float().sum() | |
| # mask1 = (real<=0) | |
| # mask2 = (fake<=0) | |
| # mask3 = mask1+mask2 | |
| # # mask = 1.0 - (mask3>0).float() | |
| # mask = (mask3>0) | |
| # fake[mask] = 1. | |
| # real[mask] = 1. | |
| # loss_ = torch.log(real)-torch.log(fake) | |
| # loss = torch.sqrt((torch.sum( loss_ ** 2) / N ) - ((torch.sum(loss_)/N)**2)) | |
| # return loss | |
class SpatialTransformer(nn.Module):
    """Warp a 2-D or 3-D image with a dense displacement field.

    An identity sampling grid for a fixed spatial ``size`` is built once and
    stored as a buffer.
    """

    def __init__(self, size, mode='bilinear'):
        """
        Instantiate the block.

        :param size: spatial size of the input, e.g. ``(H, W)`` or ``(D, H, W)``
        :param mode: interpolation mode passed to ``F.grid_sample``
        """
        super(SpatialTransformer, self).__init__()
        # Identity grid of shape (1, ndim, *size); channel i holds the index
        # along spatial axis i (torch.meshgrid 'ij' ordering: y, x[, z]).
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors)
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0)
        grid = grid.type(torch.FloatTensor)
        self.register_buffer('grid', grid)
        self.mode = mode

    def forward(self, src, flow):
        """
        Warp ``src`` by ``flow``.

        :param src: image to warp, (B, C, *size)
        :param flow: displacement in pixels, (B, ndim, *size), channels in the
            same order as the grid (axis 0 first, i.e. y before x in 2-D)
        :return: warped image with the spatial size of ``flow``
        """
        new_locs = self.grid + flow
        shape = flow.shape[2:]
        # Map pixel indices [0, size-1] onto [-1, 1]. This places the centres
        # of the first and last pixels exactly at -1 and +1, which is the
        # align_corners=True convention, so grid_sample must use it too;
        # otherwise even a zero flow shifts the image by half a pixel.
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
        # grid_sample expects coordinates in the last dim, ordered (x, y[, z]).
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]
        return F.grid_sample(src, new_locs, mode=self.mode,
                             padding_mode="border", align_corners=True)
class optical_flow(nn.Module):
    """Pixel displacement induced by projecting 3-D points with ``K @ T``.

    The projected pixel position of every point is compared against an
    identity grid, giving a flow field whose channels are ordered (y, x) to
    match the 'ij' grid.
    """

    def __init__(self, size, batch_size, height, width, eps=1e-7):
        super(optical_flow, self).__init__()
        axes = [torch.arange(0, s) for s in size]
        identity = torch.unsqueeze(torch.stack(torch.meshgrid(axes)), 0)
        self.register_buffer('grid', identity.type(torch.FloatTensor))
        self.batch_size = batch_size
        self.height = height
        self.width = width
        # Added to the projected depth to avoid division by zero.
        self.eps = eps

    def forward(self, points, K, T):
        """
        :param points: homogeneous points, (B, 4, height*width)
        :param K: intrinsics and T: pose, multiplied as ``K @ T``
        :return: flow of shape (B, 2, height, width), channels (y, x)
        """
        projection = torch.matmul(K, T)[:, :3, :]
        cam = torch.matmul(projection, points)
        depth = cam[:, 2, :].unsqueeze(1) + self.eps
        pix = (cam[:, :2, :] / depth).view(self.batch_size, 2, self.height, self.width)
        # Projected coords are (x, y); swap to (y, x) before subtracting the grid.
        return pix[:, [1, 0], ...] - self.grid
def get_corresponding_map(data):
    """
    Splat every location in ``data`` bilinearly onto the H x W pixel grid and
    return the total weight that lands on each pixel.

    Each location (x, y) spreads bilinear weights (summing to 1) over its four
    surrounding integer pixels. Corners that fall outside the image contribute
    nothing. A pixel that no location maps onto therefore ends with weight 0,
    and one that several locations map onto can exceed 1.

    :param data: unnormalized coordinates Bx2xHxW; channel 0 is x (column,
        range [0, W-1]) and channel 1 is y (row, range [0, H-1])
    :return: Bx1xHxW accumulated weights
    """
    B, _, H, W = data.size()
    x = data[:, 0, :, :].view(B, -1)  # BxN (N=H*W)
    y = data[:, 1, :, :].view(B, -1)
    # Integer neighbours: x1/y1 = floor, x0/y0 = floor + 1 (the "ceil" corner).
    # Clamped copies serve as scatter indices; comparing a clamped value with
    # its unclamped original tells us whether that corner is outside the image.
    x1 = torch.floor(x)
    x_floor = x1.clamp(0, W - 1)
    y1 = torch.floor(y)
    y_floor = y1.clamp(0, H - 1)
    x0 = x1 + 1
    x_ceil = x0.clamp(0, W - 1)
    y0 = y1 + 1
    y_ceil = y0.clamp(0, H - 1)
    x_ceil_out = x0 != x_ceil
    y_ceil_out = y0 != y_ceil
    x_floor_out = x1 != x_floor
    y_floor_out = y1 != y_floor
    # A corner is invalid if either of its coordinates had to be clamped.
    # The four-corner order here must match `indices` and `values` below.
    invalid = torch.cat([x_ceil_out | y_ceil_out,
                         x_ceil_out | y_floor_out,
                         x_floor_out | y_ceil_out,
                         x_floor_out | y_floor_out], dim=1)
    # encode (x, y) as the flat index x + y*W, since scatter_add_ only indexes along one axis
    corresponding_map = torch.zeros(B, H * W).type_as(data)
    indices = torch.cat([x_ceil + y_ceil * W,
                         x_ceil + y_floor * W,
                         x_floor + y_ceil * W,
                         x_floor + y_floor * W], 1).long()  # Bx4N (N=H*W)
    # Bilinear weight of each corner: (1 - |dx|) * (1 - |dy|).
    values = torch.cat([(1 - torch.abs(x - x_ceil)) * (1 - torch.abs(y - y_ceil)),
                        (1 - torch.abs(x - x_ceil)) * (1 - torch.abs(y - y_floor)),
                        (1 - torch.abs(x - x_floor)) * (1 - torch.abs(y - y_ceil)),
                        (1 - torch.abs(x - x_floor)) * (1 - torch.abs(y - y_floor))],
                       1)
    # Out-of-image corners were clamped onto border pixels; drop their weight.
    values[invalid] = 0
    corresponding_map.scatter_add_(1, indices, values)
    # decode the flat index back to H x W
    corresponding_map = corresponding_map.view(B, H, W)
    return corresponding_map.unsqueeze(1)
class get_occu_mask_backward(nn.Module):
    """Mask from splatting a displaced identity grid onto itself.

    Every pixel is moved by ``flow`` and splatted bilinearly with
    ``get_corresponding_map``. Pixels whose accumulated weight exceeds ``th``
    get a mask value of 1, all others 0.
    """

    def __init__(self, size):
        super(get_occu_mask_backward, self).__init__()
        # Identity grid (1, ndim, *size), channel order y, x[, z] ('ij').
        coords = [torch.arange(0, s) for s in size]
        base = torch.unsqueeze(torch.stack(torch.meshgrid(coords)), 0)
        self.register_buffer('grid', base.type(torch.FloatTensor))

    def forward(self, flow, th=0.95):
        """
        :param flow: displacement field (B, 2, H, W), channels (y, x)
        :param th: weight threshold for the mask
        :return: tuple ``(mask, weights)``, both (B, 1, H, W); mask is 0/1 float
        """
        # get_corresponding_map wants (x, y) channel order, so swap.
        targets = (self.grid + flow)[:, [1, 0], ...]
        weights = get_corresponding_map(targets)
        return (weights > th).float(), weights
class get_occu_mask_bidirection(nn.Module):
    """Forward-backward flow consistency residual.

    ``flow21`` is warped with ``flow12`` and added to ``flow12``; where the
    two flows agree the sum is close to zero. Despite the class name, the
    module returns the raw residual, not a thresholded mask.
    """

    def __init__(self, size, mode='bilinear'):
        super(get_occu_mask_bidirection, self).__init__()
        # Identity grid (1, ndim, *size), channel order y, x[, z] ('ij').
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors)
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0)
        grid = grid.type(torch.FloatTensor)
        self.register_buffer('grid', grid)
        self.mode = mode

    def forward(self, flow12, flow21, scale=0.01, bias=0.5):
        """
        :param flow12: displacement field (by its name, frame 1 -> 2),
            (B, ndim, *size), channels ordered like the grid (y before x)
        :param flow21: displacement field (by its name, frame 2 -> 1), same layout
        :param scale: unused; kept for backward compatibility
        :param bias: unused; kept for backward compatibility
        :return: ``|flow12 + warp(flow21, flow12)|``, same shape as ``flow12``
        """
        new_locs = self.grid + flow12
        shape = flow12.shape[2:]
        # Normalise to [-1, 1] with the align_corners=True convention (pixel
        # centres 0 and size-1 map to -1 and +1); grid_sample must match it.
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
        # grid_sample expects coordinates in the last dim, ordered (x, y[, z]).
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]
        flow21_warped = F.grid_sample(flow21, new_locs, mode=self.mode,
                                      padding_mode="border", align_corners=True)
        flow12_diff = torch.abs(flow12 + flow21_warped)
        return flow12_diff
| # functions | |
| def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Return the rotation matrices for one of the rotations about an axis | |
| of which Euler angles describe, for each value of the angle given. | |
| Args: | |
| axis: Axis label "X" or "Y or "Z". | |
| angle: any shape tensor of Euler angles in radians | |
| Returns: | |
| Rotation matrices as tensor of shape (..., 3, 3). | |
| """ | |
| cos = torch.cos(angle) | |
| sin = torch.sin(angle) | |
| one = torch.ones_like(angle) | |
| zero = torch.zeros_like(angle) | |
| if axis == "X": | |
| R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) | |
| elif axis == "Y": | |
| R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) | |
| elif axis == "Z": | |
| R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) | |
| else: | |
| raise ValueError("letter must be either X, Y or Z.") | |
| return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) | |
def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Convert Euler angles in radians to homogeneous 4x4 rotation matrices.

    Args:
        euler_angles: Euler angles in radians as tensor of shape (..., 3).
            All leading dimensions are flattened into one batch dimension,
            so (B, 3) and (B, 1, 3) both give B matrices and an unbatched
            (3,) gives a single one.
        convention: Convention string of three uppercase letters from
            {"X", "Y", and "Z"}. The elementary rotations are multiplied in
            that order, with the first letter leftmost.

    Returns:
        Tensor of shape (N, 4, 4), where N is the product of the leading
        dimensions of ``euler_angles``. The tensor has the default dtype and
        lives on the input's device. The upper-left 3x3 block is the
        rotation, [3, 3] is 1, and all other entries are 0.

    Raises:
        ValueError: if the angle tensor or the convention string is invalid.
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    matrices = [
        _axis_angle_rotation(c, e)
        for c, e in zip(convention, torch.unbind(euler_angles, -1))
    ]
    rotation_matrices = torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])
    # Flatten every leading dim into a single batch dim. A bare squeeze() here
    # would turn an unbatched (3, 3) result into a batch of 3 broadcast copies.
    rotation_matrices = rotation_matrices.reshape(-1, 3, 3)
    rot = torch.zeros((rotation_matrices.shape[0], 4, 4)).to(device=rotation_matrices.device)
    rot[:, :3, :3] = rotation_matrices
    rot[:, 3, 3] = 1
    return rot
| def _angle_from_tan( | |
| axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool | |
| ) -> torch.Tensor: | |
| """ | |
| Extract the first or third Euler angle from the two members of | |
| the matrix which are positive constant times its sine and cosine. | |
| Args: | |
| axis: Axis label "X" or "Y or "Z" for the angle we are finding. | |
| other_axis: Axis label "X" or "Y or "Z" for the middle axis in the | |
| convention. | |
| data: Rotation matrices as tensor of shape (..., 3, 3). | |
| horizontal: Whether we are looking for the angle for the third axis, | |
| which means the relevant entries are in the same row of the | |
| rotation matrix. If not, they are in the same column. | |
| tait_bryan: Whether the first and third axes in the convention differ. | |
| Returns: | |
| Euler Angles in radians for each matrix in data as a tensor | |
| of shape (...). | |
| """ | |
| i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] | |
| if horizontal: | |
| i2, i1 = i1, i2 | |
| even = (axis + other_axis) in ["XY", "YZ", "ZX"] | |
| if horizontal == even: | |
| return torch.atan2(data[..., i1], data[..., i2]) | |
| if tait_bryan: | |
| return torch.atan2(-data[..., i2], data[..., i1]) | |
| return torch.atan2(data[..., i2], -data[..., i1]) | |
def matrix_2_euler_vector(matrix, convention = 'ZYX', roll = True):
    """
    Split batched rigid transforms into Euler angles and translations.

    Args:
        matrix: tensor of shape (B, 4, 4) (at least (B, 3, 4)). The upper-left
            3x3 block is the rotation and ``[:, :3, 3]`` is the translation.
        convention: Euler convention passed to ``matrix_to_euler_angles``.
        roll: if True, zero the first Euler angle (the rotation about
            ``convention[0]``, Z for the default 'ZYX') of every sample.

    Returns:
        Tensor of shape (2B, 3): the B rows of Euler angles followed by the B
        rows of translations (concatenated along dim 0).
    """
    # NOTE(review): original comment says "to match with scipy euler = -euler
    # and transpose of this"; the sign/transpose relation is not verified here.
    euler = matrix_to_euler_angles(matrix[:, :3, :3], convention)
    if roll:
        # Index the angle dimension, not the batch dimension: the former
        # `euler[0] = 0.0` zeroed all three angles of sample 0 only.
        euler[..., 0] = 0.0
    t = matrix[:, :3, 3]
    out = torch.cat([euler, t], dim=0)
    return out
| def _index_from_letter(letter: str) -> int: | |
| if letter == "X": | |
| return 0 | |
| if letter == "Y": | |
| return 1 | |
| if letter == "Z": | |
| return 2 | |
| raise ValueError("letter must be either X, Y or Z.") | |
def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Convert rotation matrices to Euler angles in radians.

    Args:
        matrix: Rotation matrices, shape (..., 3, 3).
        convention: Three uppercase letters from {"X", "Y", "Z"} whose
            middle letter differs from both neighbours.

    Returns:
        Euler angles, shape (..., 3), ordered as in ``convention``.

    Raises:
        ValueError: on an invalid convention or matrix shape.
    """
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    first = _index_from_letter(convention[0])
    last = _index_from_letter(convention[2])
    tait_bryan = first != last

    # The middle angle comes straight from one matrix entry.
    if tait_bryan:
        sign = -1.0 if first - last in [-1, 2] else 1.0
        middle = torch.asin(matrix[..., first, last] * sign)
    else:
        middle = torch.acos(matrix[..., first, first])

    # The outer angles come from a column (first) and a row (third).
    outer_first = _angle_from_tan(
        convention[0], convention[1], matrix[..., last], False, tait_bryan
    )
    outer_last = _angle_from_tan(
        convention[2], convention[1], matrix[..., first, :], True, tait_bryan
    )
    return torch.stack((outer_first, middle, outer_last), -1)
def computeFID(real_images, fake_images, fid_criterion):
    """Feed a real and a generated batch to an FID metric and return its value.

    ``fid_criterion`` must expose ``update(images, real=bool)`` and
    ``compute()``; presumably a torchmetrics FrechetInceptionDistance (see the
    commented-out import). The metric is never reset here, so its state
    presumably accumulates across calls; confirm this is intended.
    """
    for batch, is_real in ((real_images, True), (fake_images, False)):
        fid_criterion.update(batch, real=is_real)
    return fid_criterion.compute()
class SLlog(nn.Module):
    """Scale-invariant log loss: standard deviation of log(real) - log(fake).

    Pixels where either input is <= 0 are neutralised: both values are set
    to 1, so they contribute 0 to the sums. The normaliser N is the number
    of pixels with ``real > 0``.
    """

    def __init__(self):
        super(SLlog, self).__init__()

    def forward(self, fake1, real1):
        """
        :param fake1: prediction, (B, C, h, w); resized bilinearly to
            ``real1``'s spatial size when the shapes differ
        :param real1: target, (B, C, H, W)
        :return: scalar loss ``sqrt(mean(d**2) - mean(d)**2)``
        """
        if fake1.shape != real1.shape:
            _, _, H, W = real1.shape
            # Resize the prediction itself. Previously this read an unbound
            # local `fake`, and the result would have been overwritten anyway.
            fake1 = F.interpolate(fake1, size=(H, W), mode='bilinear', align_corners=False)
        # Clone so the masking below never mutates the caller's tensors.
        real = real1.clone()
        fake = fake1.clone()
        N = (real > 0).float().sum()
        mask = (real <= 0) | (fake <= 0)
        fake[mask] = 1.
        real[mask] = 1.
        loss_ = torch.log(real) - torch.log(fake)
        loss = torch.sqrt((torch.sum(loss_ ** 2) / N) - ((torch.sum(loss_) / N) ** 2))
        return loss
class RMSE_log(nn.Module):
    """Root-mean-square error between log depths over pixels with ``real < 1``."""

    def __init__(self, use_cuda):
        super(RMSE_log, self).__init__()
        # Added to the prediction before taking its log.
        self.eps = 1e-8
        # Stored for callers; forward() does not read it.
        self.use_cuda = use_cuda

    def forward(self, fake, real):
        """
        :param fake: prediction (B, C, h, w), bilinearly resized to ``real``'s H x W
        :param real: target (B, C, H, W); only entries < 1 are scored
        :return: scalar RMSE of the log difference
        """
        _, _, height, width = real.size()
        valid = real < 1.
        resized = F.interpolate(fake, size=(height, width), mode='bilinear') + self.eps
        log_diff = torch.log(real[valid]) - torch.log(resized[valid])
        count = len(real[valid])
        return torch.sqrt(torch.sum(torch.abs(log_diff) ** 2) / count)
def depth_to_disp(depth, min_disp=0.00001, max_disp = 1.000001):
    """Map a normalised network output onto a depth range; also return disparity.

    ``depth`` (presumably a sigmoid output in [0, 1]) is scaled linearly onto
    ``[1 / max_disp, 1 / min_disp]``. The disparity is the reciprocal of the
    scaled depth.

    Returns:
        ``(scaled_depth, disp)``
    """
    lowest_depth = 1 / max_disp
    highest_depth = 1 / min_disp
    scaled_depth = lowest_depth + (highest_depth - lowest_depth) * depth
    return scaled_depth, 1 / scaled_depth
def disp_to_depth(disp, min_depth, max_depth):
    """Convert the network's sigmoid output into a depth prediction.

    ``disp`` is scaled linearly onto ``[1 / max_depth, 1 / min_depth]`` and
    inverted to give depth (the "additional considerations" section of the
    Monodepth2 paper).

    Returns:
        ``(scaled_disp, depth)``
    """
    disp_lo, disp_hi = 1 / max_depth, 1 / min_depth
    scaled_disp = disp_lo + (disp_hi - disp_lo) * disp
    return scaled_disp, 1 / scaled_disp
def disp_to_depth_no_scaling(disp):
    """Return depth as the plain reciprocal of ``disp``, with no range scaling.

    A small constant (1e-7) is added to ``disp`` to avoid division by zero.
    """
    eps = 1e-7
    return 1 / (disp + eps)
def transformation_from_parameters(axisangle, translation, invert=False):
    """Convert the network's (axisangle, translation) output into a 4x4 matrix.

    Returns ``T(t) @ R`` normally. With ``invert=True`` it returns
    ``R^T @ T(-t)``, the inverse of that transform.
    """
    rotation = rot_from_axisangle(axisangle)
    if invert:
        rotation = rotation.transpose(1, 2)
        shift = get_translation_matrix(-translation)
        return torch.matmul(rotation, shift)
    shift = get_translation_matrix(translation.clone())
    return torch.matmul(shift, rotation)
def transformation_from_parameters_euler(euler, translation, invert=False):
    """Convert (Euler angles, translation) into a 4x4 matrix.

    The rotation is built with ``euler_angles_to_matrix(euler, 'ZYX')``.
    Returns ``T(t) @ R`` normally, or ``R^T @ T(-t)`` with ``invert=True``.
    NOTE(review): original comment says this ordering is meant "to match with
    scipy euler = -euler and transpose of this"; not verified here.
    """
    rotation = euler_angles_to_matrix(euler, 'ZYX')
    if invert:
        rotation = rotation.transpose(1, 2)
        shift = get_translation_matrix(-translation)
        return torch.matmul(rotation, shift)
    shift = get_translation_matrix(translation.clone())
    return torch.matmul(shift, rotation)
def get_translation_matrix(translation_vector):
    """Build 4x4 homogeneous translation matrices.

    :param translation_vector: (B, ..., 3) with exactly 3 values per sample
    :return: (B, 4, 4) identity matrices with ``[:, :3, 3]`` set to the
        translation; default dtype, on the input's device
    """
    count = translation_vector.shape[0]
    mat = torch.eye(4, device=translation_vector.device).repeat(count, 1, 1)
    mat[:, :3, 3, None] = translation_vector.contiguous().view(-1, 3, 1)
    return mat
def rot_from_euler(vec):
    """Build a SciPy rotation from 'zyx' Euler angles given in degrees.

    :param vec: three angles, or an (N, 3) array of them, in degrees for
        SciPy's lowercase (extrinsic) 'zyx' sequence
    :return: ``scipy.spatial.transform.Rotation``. The rotation used to be
        built and then discarded, so callers always received None.
    """
    rot = R.from_euler('zyx', vec, degrees=True)
    return rot
def rot_from_axisangle(vec):
    """Convert an axis-angle rotation into a 4x4 homogeneous matrix.

    (adapted from https://github.com/Wallacoloo/printipi)
    The vector's norm is the angle in radians and its direction the axis
    (Rodrigues' formula). Input ``vec`` has to be Bx1x3; output is Bx4x4
    with [3, 3] = 1.
    """
    theta = torch.norm(vec, 2, 2, True)
    unit = vec / (theta + 1e-7)  # epsilon keeps a zero vector finite
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)
    one_minus_cos = 1 - cos_t

    ux, uy, uz = (unit[..., k].unsqueeze(1) for k in range(3))

    ux_s, uy_s, uz_s = ux * sin_t, uy * sin_t, uz * sin_t
    ux_c, uy_c, uz_c = ux * one_minus_cos, uy * one_minus_cos, uz * one_minus_cos
    xy_c = ux * uy_c
    yz_c = uy * uz_c
    zx_c = uz * ux_c

    entries = {
        (0, 0): ux * ux_c + cos_t,
        (0, 1): xy_c - uz_s,
        (0, 2): zx_c + uy_s,
        (1, 0): xy_c + uz_s,
        (1, 1): uy * uy_c + cos_t,
        (1, 2): yz_c - ux_s,
        (2, 0): zx_c - uy_s,
        (2, 1): yz_c + ux_s,
        (2, 2): uz * uz_c + cos_t,
    }
    out = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device)
    for (row, col), value in entries.items():
        out[:, row, col] = torch.squeeze(value)
    out[:, 3, 3] = 1
    return out
class ConvBlock(nn.Module):
    """Conv3x3 (reflection padding by default) followed by an in-place ELU."""

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv = Conv3x3(in_channels, out_channels)
        self.nonlin = nn.ELU(inplace=True)

    def forward(self, x):
        return self.nonlin(self.conv(x))
def batchNorm(num_ch_dec):
    """Return a fresh ``nn.BatchNorm2d`` over ``num_ch_dec`` channels."""
    layer = nn.BatchNorm2d(num_features=num_ch_dec)
    return layer
class Conv3x3(nn.Module):
    """Pad by one pixel (reflection or zeros) and apply a 3x3 convolution.

    The padding keeps the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels, use_refl=True):
        super(Conv3x3, self).__init__()
        pad_layer = nn.ReflectionPad2d if use_refl else nn.ZeroPad2d
        self.pad = pad_layer(1)
        self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)

    def forward(self, x):
        return self.conv(self.pad(x))
class BackprojectDepth(nn.Module):
    """Lift a depth map into homogeneous 3-D camera-space points.

    Pixel coordinates (x, y, 1) are precomputed for a fixed batch size and
    image size. They are stored as non-trainable parameters ``id_coords``,
    ``ones`` and ``pix_coords``.
    """

    def __init__(self, batch_size, height, width):
        super(BackprojectDepth, self).__init__()
        self.batch_size = batch_size
        self.height = height
        self.width = width

        cols, rows = np.meshgrid(range(width), range(height), indexing='xy')
        coords = np.stack((cols, rows), axis=0).astype(np.float32)
        self.id_coords = nn.Parameter(torch.from_numpy(coords), requires_grad=False)
        self.ones = nn.Parameter(torch.ones(batch_size, 1, height * width),
                                 requires_grad=False)

        # (B, 2, H*W) flattened x and y rows, then append the ones row.
        flat_xy = torch.stack([self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0)
        flat_xy = flat_xy.unsqueeze(0).repeat(batch_size, 1, 1)
        self.pix_coords = nn.Parameter(torch.cat([flat_xy, self.ones], 1),
                                       requires_grad=False)

    def forward(self, depth, inv_K):
        """
        :param depth: (B, 1, H, W) depth map
        :param inv_K: inverse intrinsics; only ``[:, :3, :3]`` is used
        :return: (B, 4, H*W) homogeneous points
        """
        rays = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
        points = depth.view(self.batch_size, 1, -1) * rays
        return torch.cat([points, self.ones], 1)
class Project3D(nn.Module):
    """Project homogeneous 3-D points through ``K @ T`` into a sampling grid.

    The output is in the [-1, 1] normalised form consumed by ``grid_sample``.
    """

    def __init__(self, batch_size, height, width, eps=1e-7):
        super(Project3D, self).__init__()
        self.batch_size = batch_size
        self.height = height
        self.width = width
        # Added to the projected depth to avoid division by zero.
        self.eps = eps

    def forward(self, points, K, T):
        """
        :param points: homogeneous points, (B, 4, height*width)
        :param K: intrinsics and T: pose, multiplied as ``K @ T`` (4x4 in the
            usual setup)
        :return: (B, height, width, 2) grid with the last dim ordered (x, y)
        """
        cam = torch.matmul(torch.matmul(K, T)[:, :3, :], points)
        z = cam[:, 2, :].unsqueeze(1) + self.eps
        pix = (cam[:, :2, :] / z).view(self.batch_size, 2, self.height, self.width)
        u = pix[:, 0] / (self.width - 1)
        v = pix[:, 1] / (self.height - 1)
        return (torch.stack([u, v], dim=-1) - 0.5) * 2
def upsample(x):
    """Double the spatial size of ``x`` with nearest-neighbour interpolation."""
    factor = 2
    return F.interpolate(input=x, scale_factor=factor, mode="nearest")
class deconv(nn.Module):
    """Stride-2 3x3 transposed convolution (padding 1), with no activation.

    For an input of size n the output size is 2n - 1.
    """

    def __init__(self, ch_in, ch_out):
        super(deconv, self).__init__()
        self.deconvlayer = nn.ConvTranspose2d(ch_in, ch_out, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.deconvlayer(x)
def get_smooth_loss_gauss_mask(disp, img, gauss_mask):
    """Edge-aware disparity smoothness loss, weighted per pixel by ``gauss_mask``.

    Disparity gradients are down-weighted by ``exp(-|image gradient|)``
    (averaged over image channels), then multiplied by the mask, cropped to
    the gradient shapes. Returns the sum of the x and y means.
    """
    disp_dx = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
    disp_dy = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])
    edge_x = torch.exp(-torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True))
    edge_y = torch.exp(-torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True))
    term_x = disp_dx * edge_x * gauss_mask[:, :, :, :-1]
    term_y = disp_dy * edge_y * gauss_mask[:, :, :-1, :]
    return term_x.mean() + term_y.mean()
def get_smooth_loss(disp, img):
    """Edge-aware disparity smoothness loss.

    Absolute disparity gradients along x and y are weighted by
    ``exp(-mean over channels of |image gradient|)``, so smoothness is
    enforced less across image edges. Returns the sum of the two means.
    """
    disp_dx = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
    disp_dy = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])
    img_dx = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True)
    img_dy = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True)
    return (disp_dx * torch.exp(-img_dx)).mean() + (disp_dy * torch.exp(-img_dy)).mean()
class SSIM(nn.Module):
    """Per-pixel SSIM dissimilarity ``clamp((1 - SSIM) / 2, 0, 1)``.

    Local statistics use 3x3 average pooling on reflection-padded inputs, so
    the output has the same spatial size as the inputs.
    """

    def __init__(self):
        super(SSIM, self).__init__()
        self.mu_x_pool = nn.AvgPool2d(3, 1)
        self.mu_y_pool = nn.AvgPool2d(3, 1)
        self.sig_x_pool = nn.AvgPool2d(3, 1)
        self.sig_y_pool = nn.AvgPool2d(3, 1)
        self.sig_xy_pool = nn.AvgPool2d(3, 1)
        self.refl = nn.ReflectionPad2d(1)
        # Stabilising constants of the SSIM formula.
        self.C1 = 0.01 ** 2
        self.C2 = 0.03 ** 2

    def forward(self, x, y):
        x, y = self.refl(x), self.refl(y)
        mu_x, mu_y = self.mu_x_pool(x), self.mu_y_pool(y)
        mu_x_sq, mu_y_sq = mu_x ** 2, mu_y ** 2

        var_x = self.sig_x_pool(x ** 2) - mu_x_sq
        var_y = self.sig_y_pool(y ** 2) - mu_y_sq
        cov_xy = self.sig_xy_pool(x * y) - mu_x * mu_y

        numerator = (2 * mu_x * mu_y + self.C1) * (2 * cov_xy + self.C2)
        denominator = (mu_x_sq + mu_y_sq + self.C1) * (var_x + var_y + self.C2)
        return torch.clamp((1 - numerator / denominator) / 2, 0, 1)
def compute_depth_errors(gt, pred):
    """Standard depth error metrics between ground truth and prediction.

    Returns ``(abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3)``. ``a_k`` is the
    fraction of pixels with ``max(gt/pred, pred/gt) < 1.25**k``. Both inputs
    must be strictly positive for the log and ratio terms to be finite.
    """
    ratio = torch.max(gt / pred, pred / gt)
    a1, a2, a3 = ((ratio < 1.25 ** k).float().mean() for k in (1, 2, 3))

    diff = gt - pred
    rmse = torch.sqrt((diff ** 2).mean())
    rmse_log = torch.sqrt(((torch.log(gt) - torch.log(pred)) ** 2).mean())
    abs_rel = torch.mean(torch.abs(diff) / gt)
    sq_rel = torch.mean(diff ** 2 / gt)
    return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
| """ Parts of the U-Net model """ | |
class InstanceNormDoubleConv(nn.Module):
    """Two (3x3 conv => norm => ReLU) stages with mixed normalisation.

    The first stage uses affine ``InstanceNorm2d``; the second uses
    ``BatchNorm2d``.
    NOTE(review): the BatchNorm in the second stage is unusual given the class
    name. Confirm whether it is intended before changing it, since existing
    checkpoints would be affected.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid = mid_channels or out_channels
        layers = [
            nn.Conv2d(in_channels, mid, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(mid, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)
class DoubleConv(nn.Module):
    """(3x3 conv => BatchNorm => ReLU) * 2; spatial size is preserved.

    ``mid_channels`` defaults to ``out_channels`` when falsy.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid = mid_channels or out_channels
        layers = [
            nn.Conv2d(in_channels, mid, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)
class DoubleConvIN(nn.Module):
    """(3x3 conv => affine InstanceNorm => ReLU) * 2; spatial size is preserved.

    All layers are created on the default device. Move the whole model with
    ``.to(device)``. Hard-coding ``.to('cuda')`` on the norm layers (as before)
    fails on CPU-only hosts and leaves the convs and norms on different devices.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(mid_channels, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.ReLU(inplace=True))

    def forward(self, x):
        return self.double_conv(x)
class Down(nn.Module):
    """Halve the spatial size with 2x2 max-pooling, then apply DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)
class DownIN(nn.Module):
    """Halve the spatial size with 2x2 max-pooling, then apply DoubleConvIN."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConvIN(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)
class Up(nn.Module):
    """U-Net decoder step: upsample, pad to the skip map, concatenate, DoubleConv."""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if not bilinear:
            # Learned 2x upsampling that also halves the channel count.
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # Parameter-free upsampling; DoubleConv's middle width is
            # in_channels // 2 instead.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        """
        :param x1: coarser feature map (NCHW) to upsample
        :param x2: skip-connection feature map (NCHW) whose H, W are the target
        """
        x1 = self.up(x1)
        # Pad x1 (split as evenly as possible) so its H, W match x2.
        dy = x2.size()[2] - x1.size()[2]
        dx = x2.size()[3] - x1.size()[3]
        left, top = dx // 2, dy // 2
        x1 = F.pad(x1, [left, dx - left, top, dy - top])
        return self.conv(torch.cat([x2, x1], dim=1))
class OutConv(nn.Module):
    """1x1 convolution mapping ``in_channels`` to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        return self.conv(x)
class UpIN(nn.Module):
    """Decoder step like ``Up`` but using DoubleConvIN (InstanceNorm) blocks."""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if not bilinear:
            # Learned 2x upsampling that also halves the channel count.
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConvIN(in_channels, out_channels)
        else:
            # Parameter-free upsampling; the conv block's middle width is
            # in_channels // 2 instead.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConvIN(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        """
        :param x1: coarser feature map (NCHW) to upsample
        :param x2: skip-connection feature map (NCHW) whose H, W are the target
        """
        x1 = self.up(x1)
        # Pad x1 (split as evenly as possible) so its H, W match x2.
        dy = x2.size()[2] - x1.size()[2]
        dx = x2.size()[3] - x1.size()[3]
        left, top = dx // 2, dy // 2
        x1 = F.pad(x1, [left, dx - left, top, dy - top])
        return self.conv(torch.cat([x2, x1], dim=1))
| # def gaussian_fn(M, std): | |
| # n = torch.arange(0, M) - (M - 1.0) / 2.0 | |
| # sig2 = 2 * std * std | |
| # w = torch.exp(-n ** 2 / sig2) | |
| # return w | |
| # def gkern(kernlen=256, std=128): | |
| # """Returns a 2D Gaussian kernel array.""" | |
| # gkern1d = gaussian_fn(kernlen, std=std) | |
| # gkern2d = torch.outer(gkern1d, gkern1d) | |
| # return gkern2d | |
| # A = np.random.rand(256*256).reshape([256,256]) | |
| # A = torch.from_numpy(A) | |
| # guassian_filter = gkern(256, std=32) | |
| # class GaussianLayer(nn.Module): | |
| # def __init__(self): | |
| # super(GaussianLayer, self).__init__() | |
| # self.seq = nn.Sequential( | |
| # nn.ReflectionPad2d(10), | |
| # nn.Conv2d(3, 3, 21, stride=1, padding=0, bias=None, groups=3) | |
| # ) | |
| # self.weights_init() | |
| # def forward(self, x): | |
| # return self.seq(x) | |
| # def weights_init(self): | |
| # n= np.zeros((21,21)) | |
| # n[10,10] = 1 | |
| # k = scipy.ndimage.gaussian_filter(n,sigma=3) | |
| # for name, f in self.named_parameters(): | |
| # f.data.copy_(torch.from_numpy(k)) | |