File size: 15,459 Bytes
f06f310
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
import os
import sys
import logging
import numpy as np
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F

from core.update import ManifoldBasicMultiUpdateBlock
from core.extractor import BasicEncoder, MultiBasicEncoder, ResidualBlock
from core.corr import CorrBlock1D, PytorchAlternateCorrBlock1D, CorrBlockFast1D, AlternateCorrBlock
from core.utils.utils import coords_grid, upflow8, LoggerCommon
from core.confidence import OffsetConfidence
from core.refinement import Refinement, UpdateHistory
from core import geometry as GEO
from core.utils.plane import get_pos, convert2patch, predict_disp

# Module-level logger created through the project's LoggerCommon helper with
# the tag "ARCHI"; it is used below to report the model configuration.
logger = LoggerCommon("ARCHI")

# Resolve the mixed-precision context manager once at import time.
try:
    autocast = torch.cuda.amp.autocast
except AttributeError:
    # Only a missing attribute means "this PyTorch has no AMP" (PyTorch < 1.6);
    # any other exception is a genuine error and must not be swallowed.
    class autocast:
        """No-op stand-in for ``torch.cuda.amp.autocast``.

        Accepts the same ``enabled`` argument and does nothing when used as a
        context manager.
        """

        def __init__(self, enabled):
            pass

        def __enter__(self):
            pass

        def __exit__(self, *args):
            pass

class RAFTStereo(nn.Module):
    """Iterative stereo matcher with optional confidence, geometry and refinement stages.

    The model extracts context/matching features, builds a 1D correlation
    lookup, and repeatedly updates a flow field with a multi-level GRU block.
    Depending on ``args`` it additionally estimates an offset-based
    confidence, fits per-pixel geometry parameters, and refines them.

    Sign convention used throughout this class: ``flow = coords1 - coords0``
    and ``disparity = -flow[:, :1]``.
    """

    def __init__(self, args):
        """Build the sub-networks selected by ``args``.

        Args:
            args: configuration namespace. Fields read here include
                ``hidden_dims``, ``context_norm``, ``n_downsample``,
                ``n_gru_layers``, ``shared_backbone``, ``confidence``,
                ``geo_estimator``, ``refinement``, ``slant`` and
                ``update_his`` (plus the fields echoed in the log line).

        Raises:
            ValueError: if ``args.refinement`` names an unknown refinement, or
                if ``args.slant`` has an unsupported value while a module that
                needs the parameter dimension is being built.
        """
        super().__init__()
        self.args = args

        context_dims = args.hidden_dims

        self.cnet = MultiBasicEncoder(output_dim=[args.hidden_dims, context_dims], norm_fn=args.context_norm, downsample=args.n_downsample)
        self.update_block = ManifoldBasicMultiUpdateBlock(self.args, hidden_dims=args.hidden_dims)

        # One conv per GRU level producing the (z, q, r) context terms in a single pass.
        self.context_zqr_convs = nn.ModuleList([nn.Conv2d(context_dims[i], args.hidden_dims[i]*3, 3, padding=3//2) for i in range(self.args.n_gru_layers)])

        if args.shared_backbone:
            self.conv2 = nn.Sequential(
                ResidualBlock(128, 128, 'instance', stride=1),
                nn.Conv2d(128, 256, 3, padding=1))
        else:
            self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', downsample=args.n_downsample)

        if args.confidence:
            self.confidence_computer = OffsetConfidence(args)

        if args.geo_estimator == "geometry_mlp":
            self.geometry_builder = GEO.Geometry_MLP(args)
        elif args.geo_estimator == "geometry_conv":
            self.geometry_builder = GEO.Geometry_Conv(args)
        elif args.geo_estimator == "geometry_conv_split":
            self.geometry_builder = GEO.Geometry_Conv_Split(args)

        if self._is_set(args.refinement):
            if args.refinement.lower() != "refinement":
                raise ValueError("No such refinement: {}".format(args.refinement))
            self.refine = Refinement(args, in_chans=256, dim_fea=96,
                                     dim_disp=self._disp_param_channels())

        if self.args.update_his:
            # The channel count is resolved independently of the refinement
            # branch so that enabling update_his alone no longer fails on an
            # unassigned local.
            self.update_hist = UpdateHistory(args, 128, self._disp_param_channels())

        logger.info(f"RAFTStereo ~ " +\
                    f"Confidence: {args.confidence}, offset_memory_size: {args.offset_memory_size}, " +\
                    f"offset_memory_last_iter: {args.offset_memory_last_iter}, " +\
                    f"slant: {args.slant}, slant_norm: {args.slant_norm}, " +\
                    f"geo estimator: {args.geo_estimator}, geo_fusion: {args.geo_fusion}, " +\
                    f"refine: {args.refinement}, refine_win_size: {args.refine_win_size}, num_heads:{args.num_heads}, " +\
                    f"split_win: {args.split_win}, refine_start_itr: {args.refine_start_itr}, " +\
                    f"update_his: {args.update_his}, U_thold: {args.U_thold}, " +\
                    f"stop_freeze_bn: {args.stop_freeze_bn}" )

    @staticmethod
    def _is_set(option):
        """Return True when an optional string-like setting is neither None nor empty."""
        return option is not None and len(option) > 0

    def _disp_param_channels(self):
        """Return the channel count of the disparity parameters implied by ``args.slant``.

        Returns 1 when no slant mode is configured and 6 for the ``"slant"`` /
        ``"slant_local"`` modes.

        Raises:
            ValueError: for any other ``args.slant`` value.
        """
        slant = self.args.slant
        if not self._is_set(slant):
            return 1
        if slant in ("slant", "slant_local"):
            return 6
        raise ValueError("Unsupported slant mode: {!r}".format(slant))

    def freeze_bn(self):
        """Put every ``nn.BatchNorm2d`` sub-module into eval mode."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, img):
        """Create two identical coordinate grids sized like ``img``.

        Flow is represented as the difference between the two grids:
        ``flow = coords1 - coords0``.

        Args:
            img: tensor whose shape is unpacked as ``(N, _, H, W)``.

        Returns:
            ``(coords0, coords1)``, both on ``img.device``.
        """
        N, _, H, W = img.shape

        coords0 = coords_grid(N, H, W).to(img.device)
        coords1 = coords_grid(N, H, W).to(img.device)

        return coords0, coords1

    def upsample_flow(self, flow, mask):
        """Upsample a flow field by ``2 ** n_downsample`` using a convex combination.

        Each fine-resolution value is a softmax-weighted sum over the 3x3
        coarse neighbourhood; flow magnitudes are scaled by the same factor.

        Args:
            flow: tensor of shape ``(N, D, H, W)``.
            mask: logits viewable as ``(N, 1, 9, factor, factor, H, W)``.

        Returns:
            ``(up_flow, img_coord)`` where ``up_flow`` has shape
            ``(N, D, factor*H, factor*W)`` and ``img_coord`` is the result of
            ``get_pos`` repeated over the batch, or ``None`` when no geometry
            estimator is configured.
        """
        N, D, H, W = flow.shape
        factor = 2 ** self.args.n_downsample
        mask = mask.view(N, 1, 9, factor, factor, H, W)
        mask = torch.softmax(mask, dim=2)

        up_flow = F.unfold(factor * flow, [3,3], padding=1)
        up_flow = up_flow.view(N, D, 9, 1, 1, H, W)
        up_flow = torch.sum(mask * up_flow, dim=2)

        img_coord = None
        if self._is_set(self.args.geo_estimator):
            img_coord = get_pos(H*factor, W*factor, disp=None,
                                slant=self.args.slant,
                                slant_norm=self.args.slant_norm,
                                patch_size=factor,
                                device=flow.device)                                                # (1,2,H*factor,W*factor)
            img_coord = img_coord.repeat(N,1,1,1)

        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, D, factor*H, factor*W), img_coord

    def upsample_geo(self, mask=None, mask_disp=None, params=None):
        """Upsample geometry parameters to a full-resolution disparity map.

        Every coarse cell predicts the disparity of the ``factor x factor``
        pixels it covers from its parameters via ``predict_disp``:
        ``d_p = a_q * du_{q->p} + b_q * dv_{q->p} + d_q``. When a mask is
        given, the per-cell predictions of the 3x3 neighbourhood are blended
        with softmax weights; otherwise the patches are simply tiled.

        Args:
            mask: optional logits viewable as ``(N, 1, 9, factor, factor, H, W)``;
                used only when ``mask_disp`` is ``None``.
            mask_disp: optional logits of the same layout; takes precedence
                over ``mask``.
            params: geometry parameters of shape ``(N, D, H, W)``.

        Returns:
            Tensor of shape ``(N, 1, factor*H, factor*W)``.
        """
        N, D, H, W = params.shape
        factor = 2 ** self.args.n_downsample

        # mask_disp wins over mask, mirroring the original branch order.
        weights = mask_disp if mask_disp is not None else mask
        if weights is not None:
            weights = weights.view(N, 1, 9, factor, factor, H, W)
            weights = torch.softmax(weights, dim=2)                                              # (B,1,9,factor,factor,H,W)

        delta_pq = get_pos(H*factor, W*factor, disp=None,
                            slant=self.args.slant,
                            slant_norm=self.args.slant_norm,
                            patch_size=factor,
                            device=params.device)                                                # (1,2,H*factor,W*factor)
        patch_delta_pq = convert2patch(delta_pq, patch_size=factor, div_last=False).detach()     # (1,2,factor*factor,H,W)

        disp = predict_disp(params, patch_delta_pq, patch_size=factor, mul_last=True)            # (B,factor*factor,H,W)

        if weights is None:
            # No blending requested: tile each cell's patch into the full image.
            disp = F.fold(disp.flatten(-2,-1), (H*factor,W*factor), kernel_size=factor, stride=factor).view(N,1,H*factor,W*factor)
            return disp

        disp = F.unfold(disp, [3,3], padding=1)                                                  # (B,factor*factor*9,H,W)
        disp = disp.view(N, 1, factor, factor, 9, H, W)                                          # (B,1,factor,factor,9,H,W)
        disp = disp.permute((0,1,4,2,3,5,6))                                                     # (B,1,9,factor,factor,H,W)
        disp = torch.sum(weights * disp, dim=2)                                                  # (B,1,factor,factor,H,W)
        disp = disp.permute(0, 1, 4, 2, 5, 3)                                                    # (B,1,H,factor,W,factor)
        return disp.reshape(N, 1, factor*H, factor*W)

    def _estimate_confidence(self, inp, offset_memory, itr):
        """Run the confidence head on a window of recent flow offsets.

        Returns ``None`` while fewer than ``offset_memory_size`` earlier
        iterations exist. Afterwards the window is the latest
        ``offset_memory_size`` offsets, or, once ``itr`` is past a
        non-negative ``offset_memory_last_iter``, the fixed window ending at
        that iteration.
        """
        size = self.args.offset_memory_size
        last_iter = self.args.offset_memory_last_iter
        if itr < size:
            return None
        if last_iter < 0 or itr <= last_iter:
            window = offset_memory[-size:]
        else:
            # NOTE(review): if last_iter < size the start index is negative and
            # Python slices from the end of the list -- confirm this is intended.
            window = offset_memory[last_iter - size:last_iter]
        return self.confidence_computer(inp, window)

    def forward(self, image1, image2, iters=12, flow_init=None,
                test_mode=False, vis_mode=False, enable_refinement=True):
        """Estimate disparity between a stereo image pair.

        Args:
            image1, image2: input images with values in ``[0, 255]``; they are
                rescaled to ``[-1, 1]`` here.
            iters: number of recurrent update iterations.
            flow_init: optional initial flow added to the starting coordinates.
            test_mode: return only the final low-resolution flow and the final
                upsampled prediction.
            vis_mode: return the prediction lists without the parameter lists.
            enable_refinement: allow the refinement stage (when configured).

        Returns:
            * ``test_mode``: ``(coords1 - coords0, flow_up)``. ``flow_up`` is
              the negated refined disparity when refinement ran on the last
              iteration, otherwise the upsampled flow.
            * ``vis_mode``: ``(flow_predictions, disp_predictions,
              disp_predictions_refine, confidence_list)``.
            * otherwise the same four lists plus ``params_list`` and
              ``params_list_refine``.

        Raises:
            ValueError: if ``args.corr_implementation`` is not one of
                ``"reg"``, ``"alt"``, ``"reg_cuda"``, ``"alt_cuda"``.
        """
        image1 = (2 * (image1 / 255.0) - 1.0).contiguous()
        image2 = (2 * (image2 / 255.0) - 1.0).contiguous()

        use_geometry = self._is_set(self.args.geo_estimator)
        use_refinement = self._is_set(self.args.refinement)

        # run the context network
        with autocast(enabled=self.args.mixed_precision):
            if self.args.shared_backbone:
                *cnet_list, x = self.cnet(torch.cat((image1, image2), dim=0), dual_inp=True, num_layers=self.args.n_gru_layers)
                fmap1, fmap2 = self.conv2(x).split(dim=0, split_size=x.shape[0]//2)
            else:
                cnet_list = self.cnet(image1, num_layers=self.args.n_gru_layers)
                fmap1, fmap2 = self.fnet([image1, image2])
            net_list = [torch.tanh(x[0]) for x in cnet_list]
            inp_list = [torch.relu(x[1]) for x in cnet_list]

            # Rather than running the GRU's conv layers on the context features multiple times, we do it once at the beginning
            inp_list = [list(conv(i).split(split_size=conv.out_channels//3, dim=1)) for i,conv in zip(inp_list, self.context_zqr_convs)]

        corr_impl = self.args.corr_implementation
        if corr_impl == "reg": # Default
            corr_block = CorrBlock1D
            fmap1, fmap2 = fmap1.float(), fmap2.float()
        elif corr_impl == "alt": # More memory efficient than reg
            corr_block = PytorchAlternateCorrBlock1D
            fmap1, fmap2 = fmap1.float(), fmap2.float()
        elif corr_impl == "reg_cuda": # Faster version of reg
            corr_block = CorrBlockFast1D
        elif corr_impl == "alt_cuda": # Faster version of alt
            corr_block = AlternateCorrBlock
        else:
            raise ValueError("Unknown corr_implementation: {!r}".format(corr_impl))
        corr_fn = corr_block(fmap1, fmap2, radius=self.args.corr_radius, num_levels=self.args.corr_levels)

        coords0, coords1 = self.initialize_flow(net_list[0])

        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
        disp_predictions = []
        disp_predictions_refine = []
        params_list = []
        params_list_refine = []
        confidence_list = []
        offset_memory = []
        # Defined up front so the test-mode return never hits an unbound name.
        flow_up = None
        disp_up_refine = None
        for itr in range(iters):
            coords1 = coords1.detach()
            corr = corr_fn(coords1) # index correlation volume
            flow = coords1 - coords0

            with autocast(enabled=self.args.mixed_precision):
                ## first-stage in geometry estimation
                if self.args.n_gru_layers == 3 and self.args.slow_fast_gru: # Update low-res GRU
                    net_list = self.update_block(net_list, inp_list, iter32=True, iter16=False, iter08=False, update=False)
                if self.args.n_gru_layers >= 2 and self.args.slow_fast_gru:# Update low-res GRU and mid-res GRU
                    net_list = self.update_block(net_list, inp_list, iter32=self.args.n_gru_layers==3, iter16=True, iter08=False, update=False)
                net_list, up_mask, delta_flow, up_mask_disp = self.update_block(net_list, inp_list, corr, flow, iter32=self.args.n_gru_layers==3, iter16=self.args.n_gru_layers>=2)

                ## region detection: acquire confidence
                confidence = None
                if self.args.confidence:
                    offset_memory.append(delta_flow[:,0:2])
                    confidence = self._estimate_confidence(inp_list[0], offset_memory, itr)
                confidence_list.append(confidence)

            # in stereo mode, project flow onto epipolar
            delta_flow[:,1] = 0.0

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow
            flow = coords1 - coords0

            # We do not need to upsample or output intermediate results in test_mode for raftStereo
            if test_mode and itr < iters-1 and not use_refinement:
                continue

            # upsample disparity map
            if up_mask is None:
                flow_up = upflow8(flow)
                # upflow8 yields no pixel-coordinate grid; keep the name bound.
                img_coord = None
            else:
                flow_up, img_coord = self.upsample_flow(flow, up_mask)
            flow_up = flow_up[:,:1]
            flow_predictions.append(flow_up)

            # second-stage in geometry estimation
            geo_params = None
            disp_up = None   # stays None when no geometry estimator is configured
            disparity = -flow[:,:1]
            if use_geometry:
                geo_params = self.geometry_builder(img_coord, -flow_up, disparity)
                disp_up = self.upsample_geo(mask=None, mask_disp=up_mask_disp, params=geo_params)
            params_list.append(geo_params)
            disp_predictions.append(disp_up)

            ## curvature-aware propagation
            disparity_refine = None
            geo_params_refine = None
            if use_refinement and enable_refinement and itr >= self.args.refine_start_itr:
                geo_params_refine = self.refine(geo_params, inp_list[0], confidence,
                                        if_shift=(itr-self.args.refine_start_itr)%2>0)
                # NOTE(review): the single disparity channel is broadcast over
                # every channel of coords0 -- confirm the vertical channel is
                # meant to be shifted as well.
                coords1 = coords0 - geo_params_refine[:,:1]
                disparity_refine = geo_params_refine[:,:1]
                ### update hidden state
                if self.args.update_his:
                    net_list[0] = self.update_hist(net_list[0], -disparity_refine)
            params_list_refine.append(geo_params_refine)

            # upsample refinement
            disp_up_refine = None
            if geo_params_refine is not None:
                disp_up_refine = self.upsample_geo(mask=None, mask_disp=up_mask_disp, params=geo_params_refine)
            disp_predictions_refine.append(disp_up_refine)

        if test_mode:
            if use_refinement and enable_refinement and disp_up_refine is not None:
                # Refined output is a disparity; negate it so both test-mode
                # returns share the flow sign convention (flow = -disparity).
                return coords1 - coords0, -disp_up_refine
            return coords1 - coords0, flow_up

        if vis_mode:
            return flow_predictions, disp_predictions, disp_predictions_refine, confidence_list

        return flow_predictions, disp_predictions, disp_predictions_refine, confidence_list, params_list, params_list_refine