# MatchStereo/models/cost_volume.py
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.convformer import LayerNormWithoutBias
from utils.utils import init_coords
class GlobalCorrelation(nn.Module):
    """Global correlation between a reference and a target feature map.

    The input batch is assumed to hold both views concatenated along dim 0
    (x[:B] = reference, x[B:] = target) — TODO confirm against the caller.
    In stereo mode the correlation is restricted to the horizontal axis and
    converted into a disparity cost volume; otherwise an all-pairs 2D
    correlation is built for optical-flow estimation.
    """

    def __init__(self, dim):
        super().__init__()
        self.norm = LayerNormWithoutBias(dim)
        # independent linear projections for the two views (attention-style Q/K)
        self.q = nn.Linear(dim, dim, bias=False)
        self.k = nn.Linear(dim, dim, bias=False)
        # 1/sqrt(dim) scaling applied to the dot products
        self.scale = dim**-0.5

    def forward(self, x, stereo=True):
        """Compute a matching distribution and an initial flow estimate.

        Args:
            x: stacked features of both views, [2*B, H, W, C].
            stereo: if True correlate along rows only (disparity); otherwise
                perform all-pairs 2D correlation (optical flow).

        Returns:
            (flow, volume/prob): initial flow [2*B, H, W, 2] plus the
            softmax-normalized matching volume used to produce it.
        """
        x = self.norm(x)
        ref, tgt = x.chunk(2, dim=0)
        ref, tgt = self.q(ref), self.k(tgt)
        # global correlation on horizontal direction
        B, H, W, C = ref.shape
        if stereo:
            # row-wise dot products: each ref pixel vs every target pixel on the same row
            correlation = torch.matmul(ref, tgt.transpose(-2, -1))*self.scale  # [B, H, W, W]
            # mask subsequent positions to make disparity positive
            mask = torch.triu(torch.ones((W, W), dtype=ref.dtype, device=ref.device), diagonal=1)  # [W, W]
            valid_mask = (mask == 0).unsqueeze(0).unsqueeze(0).repeat(B, H, 1, 1)  # [B, H, W, W]
            mask_ = torch.triu(torch.ones((W, W), dtype=ref.dtype, device=ref.device), diagonal=0)  # mask for input order [right, left]
            valid_mask_ = (mask_ != 0).unsqueeze(0).unsqueeze(0).repeat(B, H, 1, 1)  # upper right
            valid_mask = torch.cat((valid_mask, valid_mask_), dim=0)  # [B*2, H, W, W]
            # the transposed correlation scores the swapped view order, so both
            # matching directions are handled in one batched tensor
            correlation = torch.cat((correlation, correlation.permute(0, 1, 3, 2)), dim=0)  # [B*2, H, W, W]
            B = B*2
            # large negative fill so masked entries vanish under softmax
            # (smaller magnitude when not fp32, presumably to avoid fp16 overflow)
            correlation[~valid_mask] = -1e9 if correlation.dtype == torch.float32 else -1e4
            # build volume from correlation
            D = W  # all-pair correlation
            volume = correlation.new_zeros([B, D, H, W])
            # gather the d-th diagonal of each [W, W] correlation slab into
            # disparity plane d of the volume (paired index lists pick a diagonal)
            for d in range(D):  # most time-consuming
                volume[:B//2, d, :, d:] = correlation[:B//2, :, range(d, W), range(W-d)]
                volume[B//2:, d, :, :(W-d)] = correlation[B//2:, :, range(W-d), range(d, W)]
            # normalize over the disparity dimension
            volume = F.softmax(volume, dim=1).to(volume.dtype)
            volume_clone = volume.clone()
            # replicate the nearest in-view probability into columns where a
            # disparity-d match would fall outside the other image
            for d in range(D):  # fill out of view # second time-consuming
                volume_clone[:B//2, d, :, :d] = volume[:B//2, d, :, d:d+1]  # left
                volume_clone[B//2:, d, :, W-1-d:] = volume[B//2:, d, :, W-1-d:(W-d)]  # right
            flow = local_disparity_estimator(volume_clone)
            return flow, volume_clone
        else:
            init_grid = init_coords(ref)  # [B, H, W, 2]
            ref = ref.view(B, -1, C)  # [B, H*W, C]
            tgt = tgt.view(B, -1, C)  # [B, H*W, C]
            # all-pairs 2D correlation over the flattened spatial dims
            correlation = torch.matmul(ref, tgt.transpose(-2, -1))*self.scale  # [B, H*W, H*W]
            # append the transpose so both matching directions are covered
            correlation = torch.cat((correlation, correlation.permute(0, 2, 1)), dim=0)  # [2*B, H*W, H*W]
            init_grid = init_grid.repeat(2, 1, 1, 1)  # [2*B, H, W, 2]
            B = B * 2
            prob = F.softmax(correlation, dim=-1).to(correlation.dtype)  # [B, H*W, H*W]
            flow = local_flow_estimator(prob, init_grid)
            return flow, prob.view(B, H, W, H*W)
def local_flow_estimator(prob, init_grid, k=5):
    """
    Estimate flow as a probability-weighted average of coordinates inside a
    k x k window centred on the most likely match.

    Args:
        prob: normalized correlation volume [B, H*W, H*W]
        init_grid: init coordinate grid [B, H, W, 2]
        k: local window size (odd number)
    Returns:
        flow: optical flow field [B, H, W, 2]
    """
    B, H, W, _ = init_grid.shape
    radius = k // 2
    dev = prob.device

    # smooth the matching distribution before taking the peak so the argmax
    # is less sensitive to single-pixel noise
    smoothed = F.avg_pool2d(prob, kernel_size=k, stride=1, padding=radius).view(B, H * W, H * W)
    peak = smoothed.argmax(dim=-1).unsqueeze(-1)             # flat index of best match [B, H*W, 1]

    # flat index -> 2D coordinates, clamped so the window stays inside the image
    peak_y = torch.clamp(peak // W, radius, H - 1 - radius)  # [B, H*W, 1]
    peak_x = torch.clamp(peak % W, radius, W - 1 - radius)   # [B, H*W, 1]

    # enumerate every offset inside the k x k window
    dy, dx = torch.meshgrid(torch.arange(-radius, radius + 1, device=dev),
                            torch.arange(-radius, radius + 1, device=dev), indexing='ij')
    samp_y = (peak_y.unsqueeze(2) + dy.reshape(1, 1, k * k, 1)).long().squeeze(-1)  # [B, H*W, k*k]
    samp_x = (peak_x.unsqueeze(2) + dx.reshape(1, 1, k * k, 1)).long().squeeze(-1)  # [B, H*W, k*k]

    batch_ix = torch.arange(B, device=dev).view(B, 1, 1).expand(-1, H * W, k * k)
    window_xy = init_grid[batch_ix, samp_y, samp_x]          # coordinates of each sample [B, H*W, k*k, 2]

    # probabilities of the (un-smoothed) volume at the window samples
    flat_ix = samp_y * W + samp_x                            # [B, H*W, k*k]
    weights = torch.gather(prob, dim=-1, index=flat_ix)      # [B, H*W, k*k]

    # drop samples below the uniform-distribution level, then renormalise
    weights = torch.where(weights < 1.0 / (H * W), torch.zeros_like(weights), weights)
    total = weights.sum(dim=-1, keepdim=True).to(weights.dtype)
    total = torch.clamp(total, min=torch.finfo(total.dtype).tiny)
    weights = (weights / total).unsqueeze(-1)                # [B, H*W, k*k, 1]

    # expected correspondence under the local distribution, then flow = match - start
    match_xy = torch.sum(weights * window_xy, dim=2).to(weights.dtype)  # [B, H*W, 2]
    match_xy = match_xy.view(B, H, W, 2)
    return match_xy - init_grid
def local_disparity_estimator(cv, k=5):
    """
    Estimate disparity as a probability-weighted average over a window of k
    disparity bins centred on the most likely one, then pack it as a flow field.

    Args:
        cv: cost volume, softmax-normalised over the disparity dim [B, D, H, W]
        k: local window size (odd number)
    Returns:
        flow: [B, H, W, 2]; channel 0 is signed horizontal disparity (negated
              for the first half of the batch), channel 1 is zero
    """
    B, D, H, W = cv.shape
    radius = k // 2
    dev = cv.device

    # 1D blur along the disparity axis before locating the peak
    flattened = cv.permute(0, 2, 3, 1).view(B, -1, D)
    blurred = F.avg_pool1d(flattened, kernel_size=k, stride=1, padding=radius).view(B, H, W, D).permute(0, 3, 1, 2)
    peak = blurred.max(dim=1).indices.unsqueeze(1)           # most likely bin [B, 1, H, W]
    peak = torch.clamp(peak, radius, D - 1 - radius)         # keep the whole window inside [0, D)

    # disparity bins covered by the window
    window = peak + torch.arange(-radius, radius + 1, device=dev).view(1, k, 1, 1)  # [B, k, H, W]
    window = torch.clamp(window, 0, D - 1)

    # gather the un-blurred probabilities at those bins
    b_ix = torch.arange(B, device=dev).view(B, 1, 1, 1).expand(-1, k, H, W)
    h_ix = torch.arange(H, device=dev).view(1, 1, H, 1).expand(B, k, H, W)
    w_ix = torch.arange(W, device=dev).view(1, 1, 1, W).expand(B, k, H, W)
    weights = cv[b_ix, window, h_ix, w_ix]                   # [B, k, H, W]

    # drop bins below the uniform level and renormalise within the window
    weights = torch.where(weights < 1.0 / D, torch.zeros_like(weights), weights)
    total = torch.clamp(weights.sum(dim=1, keepdim=True).to(weights.dtype),
                        min=torch.finfo(weights.dtype).tiny)  # [B, 1, H, W]
    normalized = weights / total                             # [B, k, H, W]

    # expected disparity under the local distribution
    disp = torch.sum(normalized * window.to(normalized.dtype), dim=1).to(normalized.dtype).unsqueeze(-1)  # [B, H, W, 1]

    # inlined disp_to_flow: negate the first half of the batch, expressed with
    # where() so the graph stays ONNX-exportable
    is_left = (torch.arange(B, device=dev) < B // 2).view(B, 1, 1, 1)
    disp = torch.where(is_left, -disp, disp)
    return torch.cat((disp, torch.zeros_like(disp)), dim=-1).contiguous()  # [B, H, W, 2]
def disp_to_flow(disp, B):
    """Pack a disparity map into a 2-channel flow tensor.

    The first half of the batch gets negated disparity (leftward horizontal
    motion); the vertical channel is all zeros. The sign flip uses where()
    with a batch mask instead of in-place slice assignment so the graph
    remains exportable to ONNX.

    Args:
        disp: disparity map [B, H, W, 1]
        B: total batch size (first half is the left view)
    Returns:
        flow: [B, H, W, 2]
    """
    is_left = (torch.arange(B, device=disp.device) < B // 2).view(B, 1, 1, 1)
    signed = torch.where(is_left, -disp, disp)
    return torch.cat((signed, torch.zeros_like(signed)), dim=-1).contiguous()