import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_

from models.common import UpConv
from models.convformer import convformer
from models.attention_blocks import MatchAttentionBlock
from models.cost_volume import GlobalCorrelation


class MatchStereo(nn.Module):
    """Coarse-to-fine estimator of a 2-channel field (flow/disparity) between two images.

    A shared encoder produces a feature pyramid for both images (concatenated
    along the batch dimension).  Starting from the coarsest level, a field is
    initialised from a global correlation volume, refined by a
    ``MatchAttentionBlock`` at every level, and upsampled between levels with
    a learned convex combination (see ``upsample_field``).
    """

    def __init__(self, args,
                 refine_win_rs=(2, 2, 1, 1),  # refine window radius at 1/32, 1/16, 1/8, 1/4
                 refine_nums=(8, 8, 8, 2),
                 num_heads=(4, 4, 4, 4),
                 mlp_ratios=(2, 2, 2, 2),
                 drop_path=0.):
        """Build the encoder, the per-level upsamplers and the refinement blocks.

        Args:
            args: Configuration object. ``args.variant`` selects the encoder;
                ``args`` is also forwarded to every ``MatchAttentionBlock``.
            refine_win_rs: Window radius passed to the refinement block of
                each level (coarse to fine).
            refine_nums: Number of layers in the refinement block of each
                level (coarse to fine).
            num_heads: Number of attention heads per level.
            mlp_ratios: MLP ratio per level.
            drop_path: Upper end of the drop-path schedule; rates are spaced
                linearly from 0 to this value over all refinement layers.
        """
        super().__init__()
        # Defaults are immutable tuples and the instance keeps its own list
        # copies, so mutating ``model.refine_nums`` / ``model.num_heads``
        # can never leak into the defaults seen by later instances.
        refine_win_rs = list(refine_win_rs)
        refine_nums = list(refine_nums)
        num_heads = list(num_heads)
        mlp_ratios = list(mlp_ratios)

        self.refine_nums = refine_nums
        self.encoder = convformer(args.variant)
        self.channels = self.encoder.dims[::-1]  # resolution low to high
        self.num_heads = num_heads
        self.head_dims = [c // h for c, h in zip(self.channels, self.num_heads)]

        # Upsampling factor between consecutive levels, and the factor of the
        # final upsampling applied after the last level.
        self.factor = 2
        self.factor_last = 2 ** (len(self.channels) - len(refine_nums) + 2)

        self.field_dim = 2  # 2(flow)

        # Index 0 holds Identity placeholders so that list indices line up
        # with level indices; level 0 is never upsampled in ``forward``.
        self.up_decoders = nn.ModuleList()
        self.up_masks = nn.ModuleList()
        for i in range(len(self.channels)):
            if i > 0:
                self.up_decoders.append(UpConv(self.channels[i - 1], self.channels[i]))
                self.up_masks.append(
                    nn.Sequential(
                        nn.Conv2d(self.channels[i - 1], self.channels[i - 1], 3, padding=1),
                        nn.ReLU(inplace=True),
                        # 9 neighbour weights for each of the factor*factor sub-pixels
                        nn.Conv2d(self.channels[i - 1], (self.factor ** 2) * 9, 1, padding=0),
                    )
                )
            else:
                self.up_decoders.append(nn.Identity())
                self.up_masks.append(nn.Identity())

        # Extra mask head (accessed as ``self.up_masks[-1]``) for the final
        # upsampling by ``factor_last``.
        self.up_masks.append(
            nn.Sequential(
                nn.Conv2d(self.channels[-1], self.channels[-1] * 2, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(self.channels[-1] * 2, (self.factor_last ** 2) * 9, 1, padding=0),
            )
        )

        # One drop-path rate per refinement layer, sliced per level below.
        dp_rates = [x.item() for x in torch.linspace(0, drop_path, sum(refine_nums))]

        # MatchAttention
        self.match_attentions = nn.ModuleList()
        for i in range(len(refine_nums)):
            start = sum(refine_nums[:i])
            stop = sum(refine_nums[:i + 1])
            self.match_attentions.append(
                MatchAttentionBlock(
                    args,
                    self.channels[i],
                    win_r=refine_win_rs[i],
                    num_layer=refine_nums[i],
                    num_head=self.num_heads[i],
                    head_dim=self.head_dims[i],
                    mlp_ratio=mlp_ratios[i],
                    field_dim=self.field_dim,
                    dp_rates=dp_rates[start:stop],
                )
            )

        self.init_correlation_volume = GlobalCorrelation(self.channels[0])
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Conv2d/Linear weights (truncated normal, std 0.02) and zero their biases."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def upsample_field(self, field, mask, factor):
        """Upsample a field by ``factor`` using a convex combination of 3x3 neighbours.

        Args:
            field: Tensor of shape ``[B, H, W, D]``.
            mask: Tensor that can be viewed as ``[B, 1, 9, factor, factor, H, W]``,
                i.e. ``9 * factor * factor`` channels of unnormalised weights.
            factor: Integer upsampling factor. Field values are multiplied by
                it as well as the spatial size.

        Returns:
            Contiguous tensor of shape ``[B, factor*H, factor*W, D]``.
        """
        B, H, W, D = field.shape
        field = field.permute(0, 3, 1, 2)  # [B, D, H, W]
        mask = mask.view(B, 1, 9, factor, factor, H, W)
        # Softmax over the 9 neighbours makes the weights a convex combination.
        mask = torch.softmax(mask, dim=2).to(mask.dtype)

        up_flow = F.unfold(field * factor, [3, 3], padding=1)
        up_flow = up_flow.view(B, D, 9, 1, 1, H, W)

        # The product broadcasts to [B, D, 9, factor, factor, H, W]; summing
        # over the neighbour axis leaves [B, D, factor, factor, H, W].
        up_flow = torch.sum(mask * up_flow, dim=2).to(mask.dtype)
        up_flow = up_flow.permute(0, 4, 2, 5, 3, 1)  # [B, H, factor, W, factor, D]
        return up_flow.reshape(B, factor * H, factor * W, D).contiguous()

    def forward(self, img0, img1, stereo=True, init_flow=None):
        """Estimate optical flow/disparity between a pair of frames (bi-directional).

        Args:
            img0: First image batch; values are rescaled as ``2 * (x / 255) - 1``.
            img1: Second image batch, rescaled the same way.
            stereo: Forwarded to the correlation volume and every refinement block.
            init_flow: Optional initial field at the coarsest level. When
                ``None`` it is computed from the global correlation volume.

        Returns:
            dict with keys:
                ``'init_flow'``: the initial coarse field,
                ``'init_cv'``: the initial correlation volume, or ``None`` when
                    ``init_flow`` was supplied,
                ``'field_all'``: list of intermediate field dicts collected
                    over all levels,
                ``'field_up'``: the final field upsampled by ``factor_last``,
                ``'self_rpos'``: the last ``self_rpos`` tensor.
        """
        field_all = []
        img0 = (2 * (img0 / 255.0) - 1.0).contiguous()
        img1 = (2 * (img1 / 255.0) - 1.0).contiguous()

        x = torch.cat((img0, img1), dim=0)  # cat in batch dim
        features = self.encoder(x)  # [B*2, H, W, C]
        features = features[::-1]  # reverse 1/32, 1/16, 1/8, 1/4

        for i in range(len(features)):  # 1/32, 1/16, 1/8, 1/4
            if i == 0:
                if init_flow is None:
                    init_flow, init_cv = self.init_correlation_volume(features[i], stereo=stereo)
                else:
                    init_cv = None
                # Clone so the refinement never aliases the returned init_flow.
                field = init_flow.clone()  # [B, H, W, 2]
                self_rpos = torch.zeros_like(field)
            else:
                features[i] = self.up_decoders[i](features[i - 1], features[i])
                # Convs expect channel-first input: [B, C, H, W]
                up_mask = self.up_masks[i](features[i - 1].permute(0, 3, 1, 2))
                self_rpos = self.upsample_field(self_rpos, up_mask, self.factor)
                field = self.upsample_field(field, up_mask, self.factor)
                field_all.append({'self': field})

            features[i], self_rpos, field, fields = self.match_attentions[i](
                features[i], self_rpos, field, stereo=stereo)
            field_all.extend(fields)

        if self.training:
            # Only the first half of the (doubled) batch is upsampled; it is
            # duplicated so the output keeps the full batch size.
            B = field.shape[0]
            field_up = self.upsample_field(
                field[:B // 2],
                self.up_masks[-1](features[-1][:B // 2].permute(0, 3, 1, 2)),
                self.factor_last)
            field_up = torch.cat((field_up, field_up), dim=0)  # dummy output
        else:
            field_up = self.upsample_field(
                field,
                self.up_masks[-1](features[-1].permute(0, 3, 1, 2)),
                self.factor_last)

        return {
            'init_flow': init_flow,
            'init_cv': init_cv,
            'field_all': field_all,
            'field_up': field_up,
            'self_rpos': self_rpos,
        }