import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

class DispHead(nn.Module):
    # Two-layer conv head that maps a hidden state to a disparity residual.
    def __init__(self, input_dim=128, hidden_dim=256, output_dim=1):
        super(DispHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, output_dim, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.conv2(self.relu(self.conv1(x)))

class ConvGRU_NoCTX(nn.Module):
    # Convolutional GRU cell without a context-network input.
    def __init__(self, hidden_dim, input_dim, kernel_size=3):
        super(ConvGRU_NoCTX, self).__init__()
        self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, kernel_size, padding=kernel_size // 2)
        self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, kernel_size, padding=kernel_size // 2)
        self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, kernel_size, padding=kernel_size // 2)
        self._initialize_weights()

    def forward(self, h, *x_list):
        x = torch.cat(x_list, dim=1)
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz(hx))  # update gate
        r = torch.sigmoid(self.convr(hx))  # reset gate
        q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))  # candidate state
        h = (1 - z) * h + z * q
        return h

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
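
# Hedged usage sketch (not part of the original file; sizes are illustrative):
# the gated update h = (1 - z) * h + z * q is shape-preserving, so the hidden
# state keeps its (N, hidden_dim, H, W) shape however many extra feature maps
# are concatenated as input.
#
#   gru = ConvGRU_NoCTX(hidden_dim=128, input_dim=128 + 64)
#   h = torch.zeros(1, 128, 40, 80)
#   feat_a = torch.randn(1, 128, 40, 80)
#   feat_b = torch.randn(1, 64, 40, 80)
#   assert gru(h, feat_a, feat_b).shape == h.shape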

class SepConvGRU(nn.Module):
    # Separable ConvGRU: a horizontal (1x5) pass followed by a vertical (5x1)
    # pass approximates a full 5x5 GRU update at lower cost.
    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(SepConvGRU, self).__init__()
        self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0))

    def forward(self, h, *x):
        x = torch.cat(x, dim=1)

        # horizontal pass
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz1(hx))
        r = torch.sigmoid(self.convr1(hx))
        q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
        h = (1 - z) * h + z * q

        # vertical pass
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz2(hx))
        r = torch.sigmoid(self.convr2(hx))
        q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
        h = (1 - z) * h + z * q
        return h

class BasicShiftEncoder(nn.Module):
    # Encodes the current disparity estimate and correlation features into a
    # 128-channel motion-feature map (the last channel is the raw disparity).
    def __init__(self, args):
        super(BasicShiftEncoder, self).__init__()
        self.args = args
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1)
        self.convc1 = nn.Conv2d(cor_planes, 64, 1, padding=0)
        self.convc2 = nn.Conv2d(64, 64, 3, padding=1)
        self.convf1 = nn.Conv2d(1, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 64, 3, padding=1)
        self.conv = nn.Conv2d(64 + 64, 128 - 1, 3, padding=1)

    def forward(self, disp, corr):
        cor = F.relu(self.convc1(corr))
        cor = F.relu(self.convc2(cor))
        dis = F.relu(self.convf1(disp))
        dis = F.relu(self.convf2(dis))
        cor_dis = torch.cat([cor, dis], dim=1)
        out = F.relu(self.conv(cor_dis))
        return torch.cat([out, disp], dim=1)
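
# Shape sketch (illustrative, not from the original file): with corr_levels=4
# and corr_radius=4, corr carries 4 * (2 * 4 + 1) = 36 channels, and the
# encoder returns a 128-channel map whose last channel is the input disparity.
#
#   enc = BasicShiftEncoder(args)        # args as in the __main__ demo below
#   disp = torch.zeros(1, 1, 40, 80)
#   corr = torch.randn(1, 36, 40, 80)
#   enc(disp, corr).shape                # torch.Size([1, 128, 40, 80])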

def pool2x(x):
    # 2x average-pool downsampling
    return F.avg_pool2d(x, 3, stride=2, padding=1)

def pool4x(x):
    # 4x average-pool downsampling
    return F.avg_pool2d(x, 5, stride=4, padding=1)

def interp(x, dest):
    # bilinearly resize x to the spatial size of dest
    interp_args = {'mode': 'bilinear', 'align_corners': True}
    return F.interpolate(x, dest.shape[2:], **interp_args)

class DispBasicMultiUpdateBlock_NoCTX(nn.Module):
    # Multi-resolution GRU update block (no context network). net is a list of
    # hidden states ordered fine to coarse: net[0] at 1/8, net[1] at 1/16, and
    # net[2] at 1/32 resolution; hidden_dims is ordered coarse to fine.
    def __init__(self, args, hidden_dims=[]):
        super(DispBasicMultiUpdateBlock_NoCTX, self).__init__()
        self.args = args
        self.encoder = BasicShiftEncoder(args)
        encoder_output_dim = 128
        self.gru08 = ConvGRU_NoCTX(hidden_dims[2], encoder_output_dim + hidden_dims[1] * (args.n_gru_layers > 1))
        self.gru16 = ConvGRU_NoCTX(hidden_dims[1], hidden_dims[0] * (args.n_gru_layers == 3) + hidden_dims[2])
        self.gru32 = ConvGRU_NoCTX(hidden_dims[0], hidden_dims[1])
        self.disp_head = DispHead(hidden_dims[2], hidden_dim=256, output_dim=1)
        factor = 2 ** self.args.n_downsample
        self.mask = nn.Sequential(
            nn.Conv2d(hidden_dims[2], 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, (factor ** 2) * 9, 1, padding=0))

    def forward(self, net, corr=None, disp=None, iter08=True, iter16=True, iter32=True, update=True):
        # Update the hidden states coarse to fine; each level sees its
        # neighbors pooled down or interpolated up to its own resolution.
        if iter32:
            net[2] = self.gru32(net[2], pool2x(net[1]))
        if iter16:
            if self.args.n_gru_layers > 2:
                net[1] = self.gru16(net[1], pool2x(net[0]), interp(net[2], net[1]))
            else:
                net[1] = self.gru16(net[1], pool2x(net[0]))
        if iter08:
            motion_features = self.encoder(disp, corr)
            if self.args.n_gru_layers > 1:
                net[0] = self.gru08(net[0], motion_features, interp(net[1], net[0]))
            else:
                net[0] = self.gru08(net[0], motion_features)

        if not update:
            return net

        delta_disp = self.disp_head(net[0])
        # scale mask to balance gradients
        mask = .25 * self.mask(net[0])
        return net, mask, delta_disp
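
if __name__ == "__main__":
    # Hedged smoke test, not part of the original file: runs one full 3-level
    # update step with illustrative sizes. The fields on args (corr_levels,
    # corr_radius, n_gru_layers, n_downsample) mirror the attributes this
    # module reads; the concrete values are assumptions.
    from types import SimpleNamespace

    args = SimpleNamespace(corr_levels=4, corr_radius=4, n_gru_layers=3, n_downsample=2)
    hidden_dims = [128, 128, 128]
    block = DispBasicMultiUpdateBlock_NoCTX(args, hidden_dims=hidden_dims)

    n, h, w = 1, 32, 64  # feature-map size at the finest (1/8) level
    net = [
        torch.zeros(n, hidden_dims[2], h, w),            # 1/8 resolution
        torch.zeros(n, hidden_dims[1], h // 2, w // 2),  # 1/16 resolution
        torch.zeros(n, hidden_dims[0], h // 4, w // 4),  # 1/32 resolution
    ]
    corr = torch.randn(n, args.corr_levels * (2 * args.corr_radius + 1), h, w)
    disp = torch.zeros(n, 1, h, w)

    net, mask, delta_disp = block(net, corr=corr, disp=disp)
    print(delta_disp.shape)  # torch.Size([1, 1, 32, 64])
    print(mask.shape)        # torch.Size([1, 144, 32, 64]) since (2**2)**2 * 9 = 144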