## Network building blocks and plane-geometry modules: conv blocks, pyramid
## pooling, plane-depth computation, auction assignment, one-hot, warping.
from collections import deque

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from .. import utils
## Conv + bn + relu
class ConvBlock(nn.Module):
    """Convolution (2D/3D, forward or transposed) -> optional batch-norm -> ReLU.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size: convolution kernel size (default 3).
        stride: convolution stride (default 1).
        padding: explicit padding; when None, defaults to
            (kernel_size - 1) // 2 ('same'-style for odd kernels).
        mode: one of 'conv', 'deconv', 'conv_3d', 'deconv_3d'.
        use_bn: whether to apply batch normalization before the ReLU.

    Raises:
        ValueError: if ``mode`` is not one of the supported modes.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, padding=None, mode='conv', use_bn=True):
        super(ConvBlock, self).__init__()
        self.use_bn = use_bn
        if padding is None:
            # 'same'-style padding for odd kernel sizes.
            padding = (kernel_size - 1) // 2
        conv_layers = {
            'conv': nn.Conv2d,
            'deconv': nn.ConvTranspose2d,
            'conv_3d': nn.Conv3d,
            'deconv_3d': nn.ConvTranspose3d,
        }
        if mode not in conv_layers:
            # Raise instead of print + exit(1) so callers can handle the error.
            raise ValueError('conv mode not supported: {}'.format(mode))
        # bias=False because BN (or the ReLU alone) follows the convolution.
        self.conv = conv_layers[mode](in_planes, out_planes, kernel_size=kernel_size,
                                      stride=stride, padding=padding, bias=False)
        if self.use_bn:
            # 3D conv modes need BatchNorm3d; 2D modes use BatchNorm2d.
            bn_layer = nn.BatchNorm3d if '3d' in mode else nn.BatchNorm2d
            self.bn = bn_layer(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, inp):
        """Apply conv -> (bn) -> relu and return the activated tensor."""
        out = self.conv(inp)
        if self.use_bn:
            out = self.bn(out)
        return self.relu(out)
## The pyramid module from pyramid scene parsing
class PyramidModule(nn.Module):
    """Pyramid pooling module from pyramid scene parsing (PSPNet).

    The input is average-pooled at four scales, each branch is reduced with
    a 1x1 ConvBlock, upsampled back to the finest branch's resolution, and
    concatenated with the input along the channel dimension.
    """

    def __init__(self, options, in_planes, middle_planes, scales=[32, 16, 8, 4]):
        super(PyramidModule, self).__init__()
        # Pool window heights preserve the input's height/width aspect ratio.
        heights = [s * options.height // options.width for s in scales]
        self.pool_1 = torch.nn.AvgPool2d((heights[0], scales[0]))
        self.pool_2 = torch.nn.AvgPool2d((heights[1], scales[1]))
        self.pool_3 = torch.nn.AvgPool2d((heights[2], scales[2]))
        self.pool_4 = torch.nn.AvgPool2d((heights[3], scales[3]))
        # NOTE(review): conv_1 skips batch-norm while conv_2..4 use it —
        # looks intentional but worth confirming.
        self.conv_1 = ConvBlock(in_planes, middle_planes, kernel_size=1, use_bn=False)
        self.conv_2 = ConvBlock(in_planes, middle_planes, kernel_size=1)
        self.conv_3 = ConvBlock(in_planes, middle_planes, kernel_size=1)
        self.conv_4 = ConvBlock(in_planes, middle_planes, kernel_size=1)
        # All branches are brought back to the first (largest) scale.
        self.upsample = torch.nn.Upsample(size=(heights[0], scales[0]), mode='bilinear')

    def forward(self, inp):
        """Return inp concatenated with the four pooled branches along dim 1."""
        pairs = ((self.pool_1, self.conv_1),
                 (self.pool_2, self.conv_2),
                 (self.pool_3, self.conv_3),
                 (self.pool_4, self.conv_4))
        branches = [self.upsample(conv(pool(inp))) for pool, conv in pairs]
        return torch.cat([inp] + branches, dim=1)
## The module to compute plane depths from plane parameters
def calcPlaneDepthsModule(width, height, planes, metadata, return_ranges=False):
    """Compute per-pixel depth maps induced by a set of planes.

    Args:
        width, height: resolution of the output depth maps.
        planes: (..., numPlanes, 3) plane parameters; the vector norm is the
            plane offset and its direction the plane normal.
        metadata: camera parameters indexed as [fx, fy, cx, cy, img_w, img_h]
            — presumed layout from the indexing below; TODO confirm.
        return_ranges: when True, also return the per-pixel ray directions.

    Returns:
        planeDepths of shape (..., height, width, numPlanes), and optionally
        the (height, width, 3) ray tensor.
    """
    # Create tensors on the same device as `planes` (previously hard-coded
    # .cuda(), which broke CPU-only use); CUDA callers see no change.
    device = planes.device
    urange = (torch.arange(width, dtype=torch.float32, device=device).view((1, -1)).repeat(height, 1)
              / (float(width) + 1) * (metadata[4] + 1) - metadata[2]) / metadata[0]
    vrange = (torch.arange(height, dtype=torch.float32, device=device).view((-1, 1)).repeat(1, width)
              / (float(height) + 1) * (metadata[5] + 1) - metadata[3]) / metadata[1]
    # Per-pixel viewing ray; the [u, 1, -v] layout suggests a y-forward,
    # z-up camera convention — confirm against the rest of the project.
    ranges = torch.stack([urange, torch.ones_like(urange), -vrange], dim=-1)

    planeOffsets = torch.norm(planes, dim=-1, keepdim=True)
    # Clamp avoids division by zero for degenerate (all-zero) planes.
    planeNormals = planes / torch.clamp(planeOffsets, min=1e-4)

    normalXYZ = torch.sum(ranges.unsqueeze(-2) * planeNormals.unsqueeze(-3).unsqueeze(-3), dim=-1)
    normalXYZ[normalXYZ == 0] = 1e-4  # rays parallel to a plane would divide by zero
    planeDepths = planeOffsets.squeeze(-1).unsqueeze(-2).unsqueeze(-2) / normalXYZ
    # NOTE(review): MAX_DEPTH is not defined anywhere in this module; it
    # presumably comes from the project's utils (imported above) — confirm,
    # otherwise this line raises NameError.
    planeDepths = torch.clamp(planeDepths, min=0, max=MAX_DEPTH)
    if return_ranges:
        return planeDepths, ranges
    return planeDepths
## The module to compute depth from plane information
def calcDepthModule(width, height, planes, segmentation, non_plane_depth, metadata):
    """Fuse per-plane depth maps with the non-plane depth prediction.

    Depth candidates (one per plane plus the non-plane channel) are weighted
    by `segmentation` and summed over the channel dimension.
    """
    plane_depths = calcPlaneDepthsModule(width, height, planes, metadata)
    # Move the plane dimension ahead of the spatial dims so the candidates
    # can be concatenated with non_plane_depth along dim 1.
    plane_depths = plane_depths.transpose(-1, -2).transpose(-2, -3)
    all_depths = torch.cat([plane_depths, non_plane_depth], dim=1)
    return torch.sum(all_depths * segmentation, dim=1)
## Compute matching with the auction-based approximation algorithm
def assignmentModule(W):
    """Solve a matching for weight matrix W via the auction approximation.

    Args:
        W: (numOwners, numGoods) torch tensor of matching weights.

    Returns:
        Tensor of length numGoods on W's device: the owner index assigned to
        each good, or -1 for unassigned goods.
    """
    # The auction runs on CPU in numpy; move the result back to W's device
    # (previously hard-coded .cuda(), which broke CPU-only use).
    assignment = calcAssignment(W.detach().cpu().numpy())
    return torch.from_numpy(assignment).to(W.device)
def calcAssignment(W):
    """Auction-based approximate maximum-weight assignment.

    Owners repeatedly bid on the good maximizing (weight - price); each bid
    raises the good's price by a small fixed increment, and the outbid owner
    re-enters the queue. Owners whose best value drops below the price stay
    unassigned.

    Args:
        W: (numOwners, numGoods) numpy weight matrix.

    Returns:
        numpy int array of length numGoods: owner assigned to each good,
        or -1 if nobody bid on it.
    """
    numOwners = int(W.shape[0])
    numGoods = int(W.shape[1])
    prices = np.zeros(numGoods)
    owners = np.full(shape=(numGoods,), fill_value=-1)
    # Fixed bid increment; smaller increments give assignments closer to optimal.
    delta = 1.0 / (numGoods + 1)
    queue = deque(range(numOwners))  # owners still waiting for a good
    while queue:
        # deque.popleft is O(1); the original `queue = queue[1:]` copied the
        # whole list on every pop.
        ownerIndex = queue.popleft()
        weights = W[ownerIndex]
        goodIndex = (weights - prices).argmax()
        if weights[goodIndex] >= prices[goodIndex]:
            previousOwner = owners[goodIndex]
            if previousOwner >= 0:
                # Outbid: the previous owner must bid again.
                queue.append(previousOwner)
            owners[goodIndex] = ownerIndex
            prices[goodIndex] += delta
    return owners
## Get one-hot tensor
def oneHotModule(inp, depth):
    """One-hot encode an integer tensor.

    Args:
        inp: integer (long) tensor of arbitrary shape with values in [0, depth).
        depth: number of classes.

    Returns:
        Float tensor of shape ``inp.shape + (depth,)`` on inp's device.
    """
    inpShape = [int(size) for size in inp.shape]
    # reshape instead of view so non-contiguous inputs also work.
    flat = inp.reshape(-1)
    # Allocate on inp's device (previously hard-coded .cuda(), which broke
    # CPU-only use and mixed devices for CPU inputs).
    out = torch.zeros(int(flat.shape[0]), depth, device=inp.device)
    out.scatter_(1, flat.unsqueeze(-1), 1)
    return out.view(inpShape + [depth])
## Warp image
def warpImages(options, planes, images, transformations, metadata):
    """Warp neighbor images into the reference view using per-plane depths.

    Each pixel is back-projected to 3D with every plane's depth, transformed
    into each neighbor camera, projected, and the neighbor image is sampled
    there with grid_sample.

    Args:
        options: needs .width, .height, .numNeighborImages, .numOutputPlanes.
        planes: plane parameters (see calcPlaneDepthsModule).
        images: presumed (batch, numNeighborImages, C, H, W) — TODO confirm
            against callers.
        transformations: per-neighbor 4x4 camera transformations.
        metadata: camera parameters [fx, fy, cx, cy, img_w, img_h].

    Returns:
        Warped images stacked as (batch, numOutputPlanes, numNeighborImages, C, H, W).
    """
    planeDepths, ranges = calcPlaneDepthsModule(options.width, options.height, planes, metadata,
                                                return_ranges=True)
    # NOTE(review): the original left a debug `print(...); exit(1)` here,
    # which made everything below unreachable; removed.
    XYZ = planeDepths.unsqueeze(-1) * ranges.unsqueeze(-2)
    # Homogeneous coordinate for the 4x4 transformation; allocate on the
    # same device as XYZ rather than hard-coded .cuda().
    ones = torch.ones([int(size) for size in XYZ.shape[:-1]] + [1], device=XYZ.device)
    XYZ = torch.cat([XYZ, ones], dim=-1)
    XYZ = torch.matmul(XYZ.unsqueeze(-3), transformations.unsqueeze(-4).unsqueeze(-4))
    # Perspective division, then map pixel coordinates to [-1, 1] for grid_sample.
    UVs = XYZ[:, :, :, :, :, :2] / XYZ[:, :, :, :, :, 2:3]
    UVs = (UVs * metadata[:2] + metadata[2:4]) / metadata[4:6] * 2 - 1
    warpedImages = []
    for imageIndex in range(options.numNeighborImages):
        image = images[:, imageIndex]
        warpedImage = [F.grid_sample(image, UVs[:, :, :, imageIndex, planeIndex])
                       for planeIndex in range(options.numOutputPlanes)]
        warpedImages.append(torch.stack(warpedImage, 1))
    return torch.stack(warpedImages, 2)