File size: 6,426 Bytes
f06f310 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from core.update_disp import DispBasicMultiUpdateBlock
from core.extractor import BasicEncoder, MultiBasicEncoder, ResidualBlock
from core.corr import CorrBlock1D, PytorchAlternateCorrBlock1D, CorrBlockFast1D, AlternateCorrBlock
from core.utils.utils import hor_coords_grid
try:
    autocast = torch.cuda.amp.autocast
except AttributeError:
    # Dummy autocast for PyTorch < 1.6, where `torch.cuda.amp` does not exist.
    # Only the missing-attribute failure is expected here, so nothing else is caught.
    class autocast:
        """No-op stand-in for ``torch.cuda.amp.autocast``.

        Used as ``with autocast(enabled=...):``; it ignores its arguments and
        does nothing on enter/exit, so mixed precision is effectively off.
        """

        def __init__(self, enabled=True, *args, **kwargs):
            pass

        def __enter__(self):
            pass

        def __exit__(self, *args):
            pass
class RAFTStereoDisp(nn.Module):
    """RAFT-Stereo variant that iteratively refines a horizontal disparity field.

    The model consists of:
    * a context encoder (``cnet``);
    * a matching-feature path, either a dedicated ``fnet`` or a head
      (``conv2``) on top of the shared context backbone;
    * a correlation block chosen by ``args.corr_implementation``;
    * a multi-level GRU update block that predicts disparity residuals.

    ``args`` is a configuration namespace. The attributes read in this class are
    ``hidden_dims``, ``context_norm``, ``n_downsample``, ``n_gru_layers``,
    ``shared_backbone``, ``mixed_precision``, ``corr_implementation``,
    ``corr_radius``, ``corr_levels`` and ``slow_fast_gru``. The update block also
    receives ``args`` and may read more.
    """

    def __init__(self, args):
        super(RAFTStereoDisp, self).__init__()
        self.args = args

        # Context features use the same per-level sizes as the GRU hidden states.
        context_dims = args.hidden_dims

        self.cnet = MultiBasicEncoder(output_dim=[args.hidden_dims, context_dims], norm_fn=args.context_norm, downsample=args.n_downsample)
        self.update_block = DispBasicMultiUpdateBlock(self.args, hidden_dims=args.hidden_dims)

        # One conv per GRU level. Its output is split into three equal chunks in
        # forward(); the "zqr" name suggests these are the GRU gate terms.
        self.context_zqr_convs = nn.ModuleList([nn.Conv2d(context_dims[i], args.hidden_dims[i]*3, 3, padding=3//2) for i in range(self.args.n_gru_layers)])

        if args.shared_backbone:
            # Matching features come from the context backbone's extra output.
            self.conv2 = nn.Sequential(
                ResidualBlock(128, 128, 'instance', stride=1),
                nn.Conv2d(128, 256, 3, padding=1))
        else:
            # Separate 256-channel feature encoder for matching.
            self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', downsample=args.n_downsample)

    def freeze_bn(self):
        """Put every BatchNorm2d submodule into eval mode (freezes running stats)."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_disp(self, img):
        """Return two horizontal coordinate grids sized like ``img``'s spatial dims.

        Disparity is represented as the difference between the two grids,
        ``disp = hor_coords1 - hor_coords0``. Both grids start out identical,
        so the initial disparity is zero.
        """
        N, _, H, W = img.shape
        hor_coords0 = hor_coords_grid(N, H, W).to(img.device)
        hor_coords1 = hor_coords_grid(N, H, W).to(img.device)
        return hor_coords0, hor_coords1

    def upsample_disp(self, disp, mask):
        """Upsample a coarse disparity field by ``2 ** args.n_downsample``.

        Each fine pixel is a convex combination (softmax weights from ``mask``)
        of the 3x3 coarse neighbourhood.

        Args:
            disp: tensor of shape (N, D, H, W).
            mask: tensor with N * 9 * factor**2 * H * W elements, viewed as
                (N, 1, 9, factor, factor, H, W). Presumably (N, 9*factor**2, H, W).

        Returns:
            Tensor of shape (N, D, factor*H, factor*W). Values are scaled by
            ``factor``, presumably because disparity is measured in coarse-grid
            pixels.
        """
        N, D, H, W = disp.shape
        factor = 2 ** self.args.n_downsample
        mask = mask.view(N, 1, 9, factor, factor, H, W)
        mask = torch.softmax(mask, dim=2)

        # Gather each pixel's 3x3 neighbourhood (zero padded at the border).
        up_disp = F.unfold(factor * disp, [3, 3], padding=1)
        up_disp = up_disp.view(N, D, 9, 1, 1, H, W)

        up_disp = torch.sum(mask * up_disp, dim=2)
        # (N, D, f, f, H, W) -> (N, D, H, f, W, f) so the reshape interleaves sub-pixels.
        up_disp = up_disp.permute(0, 1, 4, 2, 5, 3)
        return up_disp.reshape(N, D, factor*H, factor*W)

    def forward(self, image1, image2, iters=12, disp_init=None, test_mode=False, vis_mode=False, intrinsic=None):
        """Estimate disparity between a stereo image pair.

        Args:
            image1, image2: image tensors with pixel values in [0, 255]. They are
                rescaled to [-1, 1] before encoding. Presumably (N, 3, H, W)
                left/right images; TODO confirm against the encoders.
            iters: number of GRU refinement iterations.
            disp_init: optional initial disparity added to ``hor_coords1``. It
                must broadcast against the coordinate grids, which have the
                spatial size of ``net_list[0]``.
            test_mode: if True, only the last iteration is upsampled and a
                ``(low_res_disp, upsampled_disp)`` tuple is returned.
            vis_mode: accepted for interface compatibility. It currently yields
                the same output as the default path.
            intrinsic: accepted for interface compatibility; unused.

        Returns:
            In test mode, ``(hor_coords1 - hor_coords0, disp_up)``. Otherwise a dict
            ``{"disp_predictions": [...]}`` with one upsampled disparity per
            iteration.

        Raises:
            ValueError: if ``test_mode`` is set with ``iters < 1`` (there would be
                no prediction to return), or if ``args.corr_implementation`` is
                not one of "reg", "alt", "reg_cuda", "alt_cuda".
        """
        if test_mode and iters < 1:
            raise ValueError("test_mode requires iters >= 1")

        # Map pixel values from [0, 255] to [-1, 1].
        image1 = (2 * (image1 / 255.0) - 1.0).contiguous()
        image2 = (2 * (image2 / 255.0) - 1.0).contiguous()

        # Run the context and feature networks.
        with autocast(enabled=self.args.mixed_precision):
            if self.args.shared_backbone:
                # Both images go through the context encoder as one batch. The
                # extra last output feeds the matching head and is split back per image.
                *cnet_list, x = self.cnet(torch.cat((image1, image2), dim=0), dual_inp=True, num_layers=self.args.n_gru_layers)
                fmap1, fmap2 = self.conv2(x).split(dim=0, split_size=x.shape[0]//2)
            else:
                cnet_list = self.cnet(image1, num_layers=self.args.n_gru_layers)
                fmap1, fmap2 = self.fnet([image1, image2])
            # Each context output is a pair: [0] initialises the GRU hidden
            # state, [1] is the GRU input.
            net_list = [torch.tanh(x[0]) for x in cnet_list]
            inp_list = [torch.relu(x[1]) for x in cnet_list]

            # Rather than running the GRU's conv layers on the context features
            # multiple times, we do it once at the beginning.
            inp_list = [list(conv(i).split(split_size=conv.out_channels//3, dim=1)) for i, conv in zip(inp_list, self.context_zqr_convs)]

        impl = self.args.corr_implementation
        if impl == "reg":  # Default
            corr_block = CorrBlock1D
            fmap1, fmap2 = fmap1.float(), fmap2.float()
        elif impl == "alt":  # More memory efficient than reg
            corr_block = PytorchAlternateCorrBlock1D
            fmap1, fmap2 = fmap1.float(), fmap2.float()
        elif impl == "reg_cuda":  # Faster version of reg
            corr_block = CorrBlockFast1D
        elif impl == "alt_cuda":  # Faster version of alt
            corr_block = AlternateCorrBlock
        else:
            # Fail loudly here instead of leaving corr_block unbound.
            raise ValueError("Unknown corr_implementation: {!r}".format(impl))
        corr_fn = corr_block(fmap1, fmap2, radius=self.args.corr_radius, num_levels=self.args.corr_levels)

        hor_coords0, hor_coords1 = self.initialize_disp(net_list[0])
        if disp_init is not None:
            hor_coords1 = hor_coords1 + disp_init

        disp_predictions = []
        for itr in range(iters):
            # Do not backpropagate through earlier iterations' coordinates.
            hor_coords1 = hor_coords1.detach()
            corr = corr_fn(hor_coords1)  # index correlation volume
            disp = hor_coords1 - hor_coords0
            with autocast(enabled=self.args.mixed_precision):
                if self.args.n_gru_layers == 3 and self.args.slow_fast_gru:  # Update low-res GRU
                    net_list = self.update_block(net_list, inp_list, iter32=True, iter16=False, iter08=False, update=False)
                if self.args.n_gru_layers >= 2 and self.args.slow_fast_gru:  # Update low-res GRU and mid-res GRU
                    net_list = self.update_block(net_list, inp_list, iter32=self.args.n_gru_layers == 3, iter16=True, iter08=False, update=False)
                net_list, up_mask, delta_disp = self.update_block(net_list, inp_list, corr, disp, iter32=self.args.n_gru_layers == 3, iter16=self.args.n_gru_layers >= 2)

            # F(t+1) = F(t) + \Delta(t)
            hor_coords1 = hor_coords1 + delta_disp

            # We do not need to upsample or output intermediate results in test_mode
            if test_mode and itr < iters - 1:
                continue

            # Upsample predictions to full resolution.
            disp_up = self.upsample_disp(hor_coords1 - hor_coords0, up_mask)
            disp_predictions.append(disp_up)

        if test_mode:
            return hor_coords1 - hor_coords0, disp_up

        # vis_mode currently produces the same output as the default path.
        return {"disp_predictions": disp_predictions}