import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.models.layers import trunc_normal_
from models.common import UpConv
from models.convformer import convformer
from models.attention_blocks import MatchAttentionBlock
from models.cost_volume import GlobalCorrelation

class MatchStereo(nn.Module):
    def __init__(self, args,
                 refine_win_rs=[2, 2, 1, 1], # refine window radius at 1/32, 1/16, 1/8, 1/4
                 refine_nums=[8, 8, 8, 2],
                 num_heads=[4, 4, 4, 4],
                 mlp_ratios=[2, 2, 2, 2],
                 drop_path=0.):
        super().__init__()
        self.refine_nums = refine_nums

        self.encoder = convformer(args.variant)
        self.channels = self.encoder.dims[::-1] # reverse so levels run from lowest to highest resolution (1/32 -> 1/4)
        self.num_heads = num_heads
        self.head_dims = [c//h for c, h in zip(self.channels, self.num_heads)]

        self.factor = 2
        self.factor_last = 2**(len(self.channels) - len(refine_nums) + 2)
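        # factor=2 upsamples between adjacent pyramid levels; factor_last lifts the
        # finest refined level back to full resolution (2**2 = 4 here, since all four
        # encoder levels are refined and the finest sits at 1/4 resolution).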
        
        self.field_dim = 2 # field has two channels (flow: dx, dy)

        self.up_decoders = nn.ModuleList()
        self.up_masks = nn.ModuleList()
        for i in range(len(self.channels)):
            if i > 0:
                self.up_decoders.append(UpConv(self.channels[i-1], self.channels[i]))
                self.up_masks.append(
                    nn.Sequential(
                    nn.Conv2d(self.channels[i-1], self.channels[i-1], 3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(self.channels[i-1], (self.factor**2)*9, 1, padding=0))
                )
            else:
                self.up_decoders.append(nn.Identity())
                self.up_masks.append(nn.Identity())

        self.up_masks.append(
            nn.Sequential(
            nn.Conv2d(self.channels[-1], self.channels[-1]*2, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.channels[-1]*2, (self.factor_last**2)*9, 1, padding=0)))
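        # Each up_mask head predicts (factor**2)*9 logits per coarse pixel: softmax
        # weights over the 3x3 coarse neighborhood for every fine sub-pixel, consumed
        # by upsample_field below for convex-combination upsampling.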

        # drop-path (stochastic depth) rates increase linearly across all refinement layers
        dp_rates = [x.item() for x in torch.linspace(0, drop_path, sum(refine_nums))]
        # MatchAttention refinement stages, one per pyramid level
        self.match_attentions = nn.ModuleList()
        for i in range(len(refine_nums)):
            self.match_attentions.append(
                MatchAttentionBlock(args, self.channels[i], win_r=refine_win_rs[i], 
                                    num_layer=refine_nums[i], num_head=self.num_heads[i], head_dim=self.head_dims[i], 
                                    mlp_ratio=mlp_ratios[i], field_dim=self.field_dim, 
                                    dp_rates=dp_rates[sum(refine_nums[:i]):sum(refine_nums[:i+1])])
            )

        self.init_correlation_volume = GlobalCorrelation(self.channels[0])
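        # Global all-pairs correlation at the coarsest level (1/32); yields the
        # initial flow/disparity and cost volume when no init_flow is supplied.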

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def upsample_field(self, field, mask, factor):
        ''' Upsample a coarse field [B, H, W, D] -> [B, factor*H, factor*W, D] via a
            learned convex combination over each 3x3 neighborhood (as in RAFT) '''
        B, H, W, D = field.shape
        field = field.permute(0, 3, 1, 2)
        mask = mask.view(B, 1, 9, factor, factor, H, W)
        mask = torch.softmax(mask, dim=2).to(mask.dtype) # convex weights over each 3x3 neighborhood
        up_flow = F.unfold(field*factor, [3,3], padding=1) # gather 3x3 neighbors; field values scaled by the upsampling factor
        up_flow = up_flow.view(B, D, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2).to(mask.dtype) # [B, D, factor, factor, H, W]
        up_flow = up_flow.permute(0, 4, 2, 5, 3, 1) # [B, H, factor, W, factor, D]
        return up_flow.reshape(B, factor*H, factor*W, D).contiguous()

    def forward(self, img0, img1, stereo=True, init_flow=None):
        ''' Estimate optical flow/disparity between a pair of frames; outputs bi-directional flow/disparity '''
        field_all = []

        # normalize images from [0, 255] to [-1, 1]
        img0 = (2 * (img0 / 255.0) - 1.0).contiguous()
        img1 = (2 * (img1 / 255.0) - 1.0).contiguous()

        x = torch.cat((img0, img1), dim=0) # cat in batch dim

        features = self.encoder(x) # [B*2, H, W, C]
        features = features[::-1] # reverse to coarse-to-fine order: 1/32, 1/16, 1/8, 1/4
        
        for i in range(len(features)): # 1/32, 1/16, 1/8, 1/4
            if i == 0:
                if init_flow is None:
                    init_flow, init_cv = self.init_correlation_volume(features[i], stereo=stereo)
                else:
                    init_cv = None

                field = init_flow.clone() # [B, H, W, 2]
                self_rpos = torch.zeros_like(field) # relative positions for self-attention, refined alongside the field
            else:
                features[i] = self.up_decoders[i](features[i-1], features[i])
                up_mask = self.up_masks[i](features[i-1].permute(0, 3, 1, 2)) # [B, C, H, W]
                self_rpos = self.upsample_field(self_rpos, up_mask, self.factor)
                field = self.upsample_field(field, up_mask, self.factor)
                field_all.append({'self':field})

            features[i], self_rpos, field, fields = self.match_attentions[i](features[i], self_rpos, field, stereo=stereo)
            field_all.extend(fields)

        if self.training:
            B = field.shape[0]
            # compute the full-resolution field only for the first half of the batch
            # (the img0 direction), then duplicate it so the output matches inference
            field_up = self.upsample_field(field[:B//2], self.up_masks[-1](features[-1][:B//2].permute(0, 3, 1, 2)), self.factor_last)
            field_up = torch.cat((field_up, field_up), dim=0) # duplicated dummy output for the second half
        else:
            field_up = self.upsample_field(field, self.up_masks[-1](features[-1].permute(0, 3, 1, 2)), self.factor_last)

        return {
            'init_flow': init_flow,
            'init_cv': init_cv,
            'field_all': field_all,
            'field_up': field_up,
            'self_rpos': self_rpos,
        }
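

if __name__ == "__main__":
    # Minimal shape-level sketch of the convex upsampling on its own. It deliberately
    # avoids constructing the full model, whose __init__ needs the repo-specific
    # `args` (encoder variant, attention-block settings); the tensor sizes below are
    # illustrative assumptions, not values from any training configuration.
    B, H, W, D, factor = 2, 40, 80, 2, 4
    field = torch.randn(B, H, W, D)                 # coarse field, e.g. at 1/4 resolution
    mask = torch.randn(B, (factor ** 2) * 9, H, W)  # convex-combination logits
    # upsample_field does not touch self, so it can be exercised unbound
    up = MatchStereo.upsample_field(None, field, mask, factor)
    print(up.shape)  # torch.Size([2, 160, 320, 2])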