import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from collections import deque

from .. import utils

# NOTE(review): the original code referenced a bare `MAX_DEPTH` that was never
# defined or imported (NameError at runtime). Fall back to 10.0 (meters) if the
# project `utils` module does not provide it -- TODO confirm the canonical value.
MAX_DEPTH = getattr(utils, 'MAX_DEPTH', 10.0)


## Conv + bn + relu
class ConvBlock(nn.Module):
    """Convolution (2D/3D, forward or transposed) + optional BatchNorm + ReLU.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size: convolution kernel size (default 3).
        stride: convolution stride (default 1).
        padding: explicit padding; if None, uses (kernel_size - 1) // 2
            ("same" padding for odd kernel sizes).
        mode: one of 'conv', 'deconv', 'conv_3d', 'deconv_3d'.
        use_bn: whether to apply BatchNorm between conv and ReLU.

    Raises:
        ValueError: if `mode` is not one of the supported strings.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1,
                 padding=None, mode='conv', use_bn=True):
        super(ConvBlock, self).__init__()
        self.use_bn = use_bn
        if padding is None:
            padding = (kernel_size - 1) // 2
        if mode == 'conv':
            self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
                                  stride=stride, padding=padding, bias=False)
        elif mode == 'deconv':
            self.conv = nn.ConvTranspose2d(in_planes, out_planes, kernel_size=kernel_size,
                                           stride=stride, padding=padding, bias=False)
        elif mode == 'conv_3d':
            self.conv = nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size,
                                  stride=stride, padding=padding, bias=False)
        elif mode == 'deconv_3d':
            self.conv = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=kernel_size,
                                           stride=stride, padding=padding, bias=False)
        else:
            # Original printed a message and called exit(1); raising is
            # recoverable and makes misconfiguration visible to callers.
            raise ValueError('conv mode not supported: ' + str(mode))
        if self.use_bn:
            # Bias is disabled on the conv above because BN absorbs it.
            self.bn = nn.BatchNorm3d(out_planes) if '3d' in mode else nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        return

    def forward(self, inp):
        if self.use_bn:
            return self.relu(self.bn(self.conv(inp)))
        return self.relu(self.conv(inp))


## The pyramid module from pyramid scene parsing
class PyramidModule(nn.Module):
    """PSP-style pyramid pooling: pool the feature map at four scales, reduce
    channels with 1x1 convs, upsample back to the largest scale, and
    concatenate the results with the input along the channel dimension.

    Output channel count is in_planes + 4 * middle_planes.
    """

    def __init__(self, options, in_planes, middle_planes, scales=(32, 16, 8, 4)):
        super(PyramidModule, self).__init__()
        # Each pooling kernel is sized so the pooled map has `scales[k]` columns
        # (and the height scaled by options.height / options.width).
        self.pool_1 = nn.AvgPool2d((scales[0] * options.height // options.width, scales[0]))
        self.pool_2 = nn.AvgPool2d((scales[1] * options.height // options.width, scales[1]))
        self.pool_3 = nn.AvgPool2d((scales[2] * options.height // options.width, scales[2]))
        self.pool_4 = nn.AvgPool2d((scales[3] * options.height // options.width, scales[3]))
        # NOTE(review): conv_1 skips BatchNorm while conv_2..conv_4 use it --
        # presumably intentional (global-context branch); confirm.
        self.conv_1 = ConvBlock(in_planes, middle_planes, kernel_size=1, use_bn=False)
        self.conv_2 = ConvBlock(in_planes, middle_planes, kernel_size=1)
        self.conv_3 = ConvBlock(in_planes, middle_planes, kernel_size=1)
        self.conv_4 = ConvBlock(in_planes, middle_planes, kernel_size=1)
        self.upsample = nn.Upsample(size=(scales[0] * options.height // options.width, scales[0]),
                                    mode='bilinear')
        return

    def forward(self, inp):
        x_1 = self.upsample(self.conv_1(self.pool_1(inp)))
        x_2 = self.upsample(self.conv_2(self.pool_2(inp)))
        x_3 = self.upsample(self.conv_3(self.pool_3(inp)))
        x_4 = self.upsample(self.conv_4(self.pool_4(inp)))
        return torch.cat([inp, x_1, x_2, x_3, x_4], dim=1)


## The module to compute plane depths from plane parameters
def calcPlaneDepthsModule(width, height, planes, metadata, return_ranges=False):
    """Compute per-pixel depth maps for each plane hypothesis.

    Args:
        width, height: output depth-map resolution.
        planes: (..., numPlanes, 3) plane parameters; norm is the plane offset,
            direction is the plane normal.
        metadata: camera intrinsics vector -- indices 0/1 are focal lengths,
            2/3 the principal point, 4/5 the original image size
            (assumed layout; TODO confirm against caller).
        return_ranges: also return the per-pixel viewing-ray tensor.

    Returns:
        planeDepths of shape (..., height, width, numPlanes), clamped to
        [0, MAX_DEPTH]; optionally (planeDepths, ranges).
    """
    urange = (torch.arange(width, dtype=torch.float32).cuda().view((1, -1)).repeat(height, 1)
              / (float(width) + 1) * (metadata[4] + 1) - metadata[2]) / metadata[0]
    vrange = (torch.arange(height, dtype=torch.float32).cuda().view((-1, 1)).repeat(1, width)
              / (float(height) + 1) * (metadata[5] + 1) - metadata[3]) / metadata[1]
    # Viewing ray per pixel; y is the forward axis, z points up (-vrange).
    ranges = torch.stack([urange, torch.ones(urange.shape).cuda(), -vrange], dim=-1)

    planeOffsets = torch.norm(planes, dim=-1, keepdim=True)
    # Clamp avoids division by zero for degenerate (all-zero) planes.
    planeNormals = planes / torch.clamp(planeOffsets, min=1e-4)

    normalXYZ = torch.sum(ranges.unsqueeze(-2) * planeNormals.unsqueeze(-3).unsqueeze(-3), dim=-1)
    # Rays parallel to a plane would divide by zero; nudge them instead.
    normalXYZ[normalXYZ == 0] = 1e-4
    planeDepths = planeOffsets.squeeze(-1).unsqueeze(-2).unsqueeze(-2) / normalXYZ
    planeDepths = torch.clamp(planeDepths, min=0, max=MAX_DEPTH)
    if return_ranges:
        return planeDepths, ranges
    return planeDepths


## The module to compute depth from plane information
def calcDepthModule(width, height, planes, segmentation, non_plane_depth, metadata):
    """Blend per-plane depth maps and the non-planar depth prediction using the
    (soft) segmentation weights; returns a single depth map per sample."""
    planeDepths = calcPlaneDepthsModule(width, height, planes, metadata)
    # Move the plane dimension in front of the spatial dims to match
    # segmentation's channel layout before concatenating.
    allDepths = torch.cat([planeDepths.transpose(-1, -2).transpose(-2, -3), non_plane_depth], dim=1)
    return torch.sum(allDepths * segmentation, dim=1)


## Compute matching with the auction-based approximation algorithm
def assignmentModule(W):
    """Run the (CPU, numpy) auction assignment on weight matrix W and return
    the owner index per good as a CUDA tensor."""
    O = calcAssignment(W.detach().cpu().numpy())
    return torch.from_numpy(O).cuda()


def calcAssignment(W):
    """Approximate assignment via a simple auction.

    Args:
        W: (numOwners, numGoods) weight matrix (numpy).

    Returns:
        O: (numGoods,) array; O[g] is the owner assigned to good g, or -1.
    """
    numOwners = int(W.shape[0])
    numGoods = int(W.shape[1])
    P = np.zeros(numGoods)                            # current price per good
    O = np.full(shape=(numGoods, ), fill_value=-1)    # current owner per good
    delta = 1.0 / (numGoods + 1)                      # price increment (epsilon)
    queue = deque(range(numOwners))                   # owners still bidding
    while queue:
        ownerIndex = queue.popleft()
        weights = W[ownerIndex]
        goodIndex = (weights - P).argmax()
        if weights[goodIndex] >= P[goodIndex]:
            # Outbid the previous owner, who re-enters the queue.
            if O[goodIndex] >= 0:
                queue.append(O[goodIndex])
            O[goodIndex] = ownerIndex
            P[goodIndex] += delta
    return O


## Get one-hot tensor
def oneHotModule(inp, depth):
    """One-hot encode an integer tensor; output shape is inp.shape + (depth,)."""
    inpShape = [int(size) for size in inp.shape]
    inp = inp.view(-1)
    out = torch.zeros(int(inp.shape[0]), depth).cuda()
    out.scatter_(1, inp.unsqueeze(-1), 1)
    return out.view(inpShape + [depth])


## Warp image
def warpImages(options, planes, images, transformations, metadata):
    """Warp neighbor images into the reference view, once per plane hypothesis.

    For each plane, pixels are lifted to 3D via the plane depth, transformed
    into each neighbor camera, projected, and sampled with grid_sample.

    Returns a tensor stacking warps over planes (dim 1) and neighbors (dim 2).
    """
    planeDepths, ranges = calcPlaneDepthsModule(options.width, options.height, planes,
                                                metadata, return_ranges=True)
    # NOTE(review): the original had a leftover debug `print(...); exit(1)`
    # here that unconditionally killed the process; removed.
    XYZ = planeDepths.unsqueeze(-1) * ranges.unsqueeze(-2)
    # Homogeneous coordinates for the 4x4 (assumed) transformation matmul.
    XYZ = torch.cat([XYZ, torch.ones([int(size) for size in XYZ.shape[:-1]] + [1]).cuda()], dim=-1)
    XYZ = torch.matmul(XYZ.unsqueeze(-3), transformations.unsqueeze(-4).unsqueeze(-4))
    # Perspective divide, then map to grid_sample's [-1, 1] coordinate range.
    UVs = XYZ[:, :, :, :, :, :2] / XYZ[:, :, :, :, :, 2:3]
    UVs = (UVs * metadata[:2] + metadata[2:4]) / metadata[4:6] * 2 - 1
    warpedImages = []
    for imageIndex in range(options.numNeighborImages):
        warpedImage = []
        image = images[:, imageIndex]
        for planeIndex in range(options.numOutputPlanes):
            warpedImage.append(F.grid_sample(image, UVs[:, :, :, imageIndex, planeIndex]))
        warpedImages.append(torch.stack(warpedImage, 1))
    return torch.stack(warpedImages, 2)