# MatchStereo/models/match_stereo.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_
from models.common import UpConv
from models.convformer import convformer
from models.attention_blocks import MatchAttentionBlock
from models.cost_volume import GlobalCorrelation
class MatchStereo(nn.Module):
    """Coarse-to-fine stereo / optical-flow matching network.

    A convformer encoder produces a feature pyramid; a global correlation
    volume initialises the matching field at the coarsest scale, and a
    stack of MatchAttention blocks refines it scale by scale, with learned
    convex-combination (RAFT-style) upsampling between scales and a final
    upsample to full image resolution.
    """

    # NOTE(review): the list defaults below are shared mutable objects; they are
    # never mutated here, but consider tuples if this ever changes.
    def __init__(self, args,
                 refine_win_rs=[2, 2, 1, 1],  # refine window radius at 1/32, 1/16, 1/8, 1/4
                 refine_nums=[8, 8, 8, 2],    # number of refinement layers per scale
                 num_heads=[4, 4, 4, 4],      # attention heads per scale
                 mlp_ratios=[2, 2, 2, 2],     # MLP expansion ratio per scale
                 drop_path=0.):               # max stochastic-depth rate (ramped linearly over layers)
        super().__init__()
        self.refine_nums = refine_nums
        self.encoder = convformer(args.variant)
        # encoder.dims is reversed so that index 0 is the coarsest scale
        self.channels = self.encoder.dims[::-1]  # resolution low to high
        self.num_heads = num_heads
        self.head_dims = [c//h for c, h in zip(self.channels, self.num_heads)]
        # upsampling factor between adjacent pyramid levels
        self.factor = 2
        # factor from the finest refined scale up to full image resolution
        # (with 4 channels and 4 refine stages this is 2**2 = 4, i.e. 1/4 res -> full res)
        self.factor_last = 2**(len(self.channels) - len(refine_nums) + 2)
        self.field_dim = 2  # 2(flow)
        self.up_decoders = nn.ModuleList()
        self.up_masks = nn.ModuleList()
        for i in range(len(self.channels)):
            if i > 0:
                # fuse the coarser decoded feature into the current scale
                self.up_decoders.append(UpConv(self.channels[i-1], self.channels[i]))
                # predict 9 convex-combination weights for each of factor**2 sub-pixels,
                # from the coarser-scale feature map
                self.up_masks.append(
                    nn.Sequential(
                        nn.Conv2d(self.channels[i-1], self.channels[i-1], 3, padding=1),
                        nn.ReLU(inplace=True),
                        nn.Conv2d(self.channels[i-1], (self.factor**2)*9, 1, padding=0))
                )
            else:
                # coarsest scale has nothing above it; placeholders keep indices aligned
                self.up_decoders.append(nn.Identity())
                self.up_masks.append(nn.Identity())
        # one extra mask head (accessed as up_masks[-1]) for the final upsample
        # from the finest refined scale to full resolution
        self.up_masks.append(
            nn.Sequential(
                nn.Conv2d(self.channels[-1], self.channels[-1]*2, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(self.channels[-1]*2, (self.factor_last**2)*9, 1, padding=0)))
        # stochastic-depth rates, linearly increasing across all refinement layers
        dp_rates = [x.item() for x in torch.linspace(0, drop_path, sum(refine_nums))]
        # MatchAttention
        self.match_attentions = nn.ModuleList()
        for i in range(len(refine_nums)):
            self.match_attentions.append(
                MatchAttentionBlock(args, self.channels[i], win_r=refine_win_rs[i],
                                    num_layer=refine_nums[i], num_head=self.num_heads[i], head_dim=self.head_dims[i],
                                    mlp_ratio=mlp_ratios[i], field_dim=self.field_dim,
                                    # each stage gets its contiguous slice of the ramp
                                    dp_rates=dp_rates[sum(refine_nums[:i]):sum(refine_nums[:i+1])])
            )
        self.init_correlation_volume = GlobalCorrelation(self.channels[0])
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal init (std 0.02) for conv/linear weights, zero bias.
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def upsample_field(self, field, mask, factor):
        ''' Upsample field [H/factor, W/factor, D] -> [H, W, D] using convex combination.

        field: channels-last [B, H, W, D] at the coarse scale.
        mask:  [B, 9*factor**2, H, W] raw weights; softmax over the 9 neighbours
               turns them into convex-combination coefficients.
        Displacement values are multiplied by `factor` so they stay in the
        finer grid's pixel units.
        '''
        B, H, W, D = field.shape
        field = field.permute(0, 3, 1, 2)
        mask = mask.view(B, 1, 9, factor, factor, H, W)
        mask = torch.softmax(mask, dim=2).to(mask.dtype)  # normalise over 9 neighbours
        up_flow = F.unfold(field*factor, [3,3], padding=1)  # gather each pixel's 3x3 neighbourhood
        up_flow = up_flow.view(B, D, 9, 1, 1, H, W)
        up_flow = torch.sum(mask * up_flow, dim=2).to(mask.dtype)  # [B, D, factor, factor, H, W]
        up_flow = up_flow.permute(0, 4, 2, 5, 3, 1)  # [B, H, factor, W, factor, D]
        return up_flow.reshape(B, factor*H, factor*W, D).contiguous()

    def forward(self, img0, img1, stereo=True, init_flow=None):
        ''' Estimate optical flow/disparity between pair of frames, output bi-directional flow/disparity.

        img0/img1: images with values in [0, 255] (normalised to [-1, 1] here);
                   assumed [B, 3, H, W] -- TODO confirm layout against the encoder.
        stereo:    passed through to the correlation volume and attention blocks.
        init_flow: optional externally supplied coarse field; when given, the
                   initial correlation volume is skipped and 'init_cv' is None.
        Returns a dict with the initial coarse flow, the initial cost volume,
        all intermediate fields, the full-resolution field and self_rpos.
        '''
        field_all = []
        # normalise images from [0, 255] to [-1, 1]
        img0 = (2 * (img0 / 255.0) - 1.0).contiguous()
        img1 = (2 * (img1 / 255.0) - 1.0).contiguous()
        x = torch.cat((img0, img1), dim=0)  # cat in batch dim (both views share the encoder)
        features = self.encoder(x)  # [B*2, H, W, C]
        features = features[::-1]  # reverse 1/32, 1/16, 1/8, 1/4
        for i in range(len(features)):  # 1/32, 1/16, 1/8, 1/4
            if i==0:
                # coarsest scale: initialise the field from global correlation
                # unless a starting field was supplied by the caller
                if init_flow is None:
                    init_flow, init_cv = self.init_correlation_volume(features[i], stereo=stereo)
                else:
                    init_cv = None
                field = init_flow.clone()  # [B, H, W, 2]
                self_rpos = torch.zeros_like(field)
            else:
                # fuse coarser features in, then upsample field/self_rpos to this scale
                features[i] = self.up_decoders[i](features[i-1], features[i])
                up_mask = self.up_masks[i](features[i-1].permute(0, 3, 1, 2))  # [B, C, H, W]
                self_rpos = self.upsample_field(self_rpos, up_mask, self.factor)
                field = self.upsample_field(field, up_mask, self.factor)
            field_all.append({'self':field})
            # refine feature, self relative position and field at this scale
            features[i], self_rpos, field, fields = self.match_attentions[i](features[i], self_rpos, field, stereo=stereo)
            field_all.extend(fields)
        # final convex upsample to full image resolution
        if self.training:
            # only the first half of the doubled batch is upsampled, then duplicated
            # so the output batch size matches eval mode (second half is a dummy copy;
            # presumably to skip redundant work during training -- confirm with loss code)
            B = field.shape[0]
            field_up = self.upsample_field(field[:B//2], self.up_masks[-1](features[-1][:B//2].permute(0, 3, 1, 2)), self.factor_last)
            field_up = torch.cat((field_up, field_up), dim=0)  # dummy output
        else:
            field_up = self.upsample_field(field, self.up_masks[-1](features[-1].permute(0, 3, 1, 2)), self.factor_last)
        return {
            'init_flow': init_flow,
            'init_cv': init_cv,
            'field_all': field_all,
            'field_up': field_up,
            'self_rpos': self_rpos,
        }