BFZD233
initial
f06f310
import torch
import torch.nn.functional as F
from core.utils.utils import bilinear_sampler
try:
import corr_sampler
except:
pass
try:
import alt_cuda_corr
except:
# alt_cuda_corr is not compiled
pass
class CorrSampler(torch.autograd.Function):
    """Autograd wrapper around the compiled ``corr_sampler`` extension.

    Both passes are delegated to the extension module; this class only
    records what the backward pass needs and unpacks the single-element
    results the extension returns.
    """

    @staticmethod
    def forward(ctx, volume, coords, radius):
        """Sample ``volume`` at ``coords`` with the given ``radius``.

        The exact tensor layouts are defined by the ``corr_sampler``
        extension, which is not visible here.
        """
        ctx.radius = radius
        ctx.save_for_backward(volume, coords)
        (sampled,) = corr_sampler.forward(volume, coords, radius)
        return sampled

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``volume``; other inputs get ``None``."""
        saved_volume, saved_coords = ctx.saved_tensors
        # The extension is handed a contiguous gradient buffer.
        (volume_grad,) = corr_sampler.backward(
            saved_volume, saved_coords, grad_output.contiguous(), ctx.radius
        )
        # One slot per forward input: volume, coords, radius.
        return volume_grad, None, None
class CorrBlockFast1D:
    """Row-wise all-pairs correlation pyramid sampled through ``CorrSampler``.

    The constructor computes the correlation between every pair of columns
    on the same row of ``fmap1`` and ``fmap2`` and stores ``num_levels``
    versions of it, each one average-pooled by 2 along the ``fmap2`` width
    relative to the previous one.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        """Build the correlation pyramid.

        Args:
            fmap1: 4-D feature map, unpacked as (B, D, H, W1).
            fmap2: 4-D feature map, unpacked as (B, D, H, W2).
            num_levels: number of pyramid levels to store.
            radius: lookup radius forwarded to ``CorrSampler`` on each call.
        """
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlockFast1D.corr(fmap1, fmap2)
        batch, h1, w1, dim, w2 = corr.shape
        corr = corr.reshape(batch*h1*w1, dim, 1, w2)
        for level in range(self.num_levels):
            # Pool only *before* a level that is actually stored.  Pooling
            # once more after the last level is wasted work and raises
            # "Output size is too small" when the last level has width 1.
            if level > 0:
                corr = F.avg_pool2d(corr, [1, 2], stride=[1, 2])
            self.corr_pyramid.append(corr.view(batch, h1, w1, -1, w2 // 2**level))

    def __call__(self, coords):
        """Look up every pyramid level around ``coords``.

        Only channel 0 of ``coords`` (a 4-D tensor) is used; it is divided
        by ``2**level`` before being handed to ``CorrSampler``.  The
        per-level results are reshaped to (B, -1, H, W) and concatenated
        along dim 1.
        """
        bz, _, ht, wd = coords.shape
        coords = coords[:, [0]]
        out_pyramid = []
        for level in range(self.num_levels):
            volume = self.corr_pyramid[level].squeeze(3)
            corr = CorrSampler.apply(volume, coords / 2**level, self.radius)
            out_pyramid.append(corr.view(bz, -1, ht, wd))
        return torch.cat(out_pyramid, dim=1)

    @staticmethod
    def corr(fmap1, fmap2):
        """Return the (B, H, W1, 1, W2) row-wise dot-product correlation.

        The feature dimension D is contracted and the result is divided by
        ``sqrt(D)``.
        """
        B, D, H, W1 = fmap1.shape
        _, _, _, W2 = fmap2.shape
        corr = torch.einsum('aijk,aijh->ajkh', fmap1, fmap2)
        corr = corr.reshape(B, H, W1, 1, W2).contiguous()
        return corr / torch.sqrt(torch.tensor(D).float())
class PytorchAlternateCorrBlock1D:
    """Correlation lookup computed on the fly with ``F.grid_sample``.

    No correlation volume is stored: each call samples ``fmap2`` (pooled by
    2 along its width per level) at the requested positions and correlates
    the samples with ``fmap1``.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        """Store the feature maps and lookup settings.

        Args:
            fmap1: feature map multiplied element-wise with the samples.
            fmap2: 4-D feature map (B, D, H, W) that is sampled.
            num_levels: number of pyramid levels evaluated per call.
            radius: number of offsets on each side of the centre position.
        """
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []
        self.fmap1 = fmap1
        self.fmap2 = fmap2

    def corr(self, fmap1, fmap2, coords):
        """Correlate ``fmap1`` with ``fmap2`` sampled at ``coords``.

        Args:
            fmap1: feature map broadcast against the (B, D, H_out, W_out)
                samples.
            fmap2: (B, D, H, W) feature map to sample.
            coords: (..., 2) pixel coordinates, last dim ordered (x, y);
                dim 3 enumerates the lookup offsets.

        Returns:
            Tensor with the offsets as last dim, divided by ``sqrt(D)``.
        """
        B, D, H, W = fmap2.shape
        # map grid coordinates to [-1,1]
        xgrid, ygrid = coords.split([1, 1], dim=-1)
        xgrid = 2*xgrid/(W-1) - 1
        ygrid = 2*ygrid/(H-1) - 1
        grid = torch.cat([xgrid, ygrid], dim=-1)

        # One grid_sample per offset keeps the sampled tensor at the size of
        # a single feature map.
        output_corr = [
            torch.sum(F.grid_sample(fmap2, grid_slice, align_corners=True) * fmap1, dim=1)
            for grid_slice in grid.unbind(3)
        ]
        corr = torch.stack(output_corr, dim=1).permute(0, 2, 3, 1)
        return corr / torch.sqrt(torch.tensor(D).float())

    def __call__(self, coords):
        """Return the lookup for ``coords``.

        ``coords`` must have 2 channels (it is reshaped to (..., 1, 2)).
        Channel 0 is divided by ``2**level`` per level and receives the
        offsets; channel 1 is passed through unchanged.  The result has
        shape (B, num_levels * (2 * radius + 1), H, W) and dtype float32.
        """
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape
        fmap1 = self.fmap1
        fmap2 = self.fmap2

        # (2r+1, 2) offsets: -r..r on channel 0, zero on channel 1.  Built
        # once (it does not depend on the level) and without torch.meshgrid,
        # whose call without ``indexing`` emits a deprecation warning.
        offsets = torch.linspace(-r, r, 2*r+1)
        delta = torch.stack([offsets, torch.zeros_like(offsets)], dim=-1).to(coords.device)

        out_pyramid = []
        for i in range(self.num_levels):
            # Pool only before a level that is actually evaluated; the
            # original pooled once more after the last level for nothing.
            if i > 0:
                fmap2 = F.avg_pool2d(fmap2, [1, 2], stride=[1, 2])
            centroid_lvl = coords.reshape(batch, h1, w1, 1, 2).clone()
            centroid_lvl[..., 0] = centroid_lvl[..., 0] / 2**i
            coords_lvl = centroid_lvl + delta
            out_pyramid.append(self.corr(fmap1, fmap2, coords_lvl))
        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()
class PytorchAlternateAbsCorrBlock1D:
    """On-the-fly correlation lookup over a pre-pooled ``fmap2`` pyramid.

    The constructor stores ``fmap2`` at ``num_levels`` widths (each pooled
    by 2 along the width); every call samples all offsets of a level with a
    single ``F.grid_sample`` and correlates them with ``fmap1``.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        """Store ``fmap1`` and build the ``fmap2`` pyramid.

        Args:
            fmap1: 4-D feature map (B, C, H, W).
            fmap2: 4-D feature map that is pooled along its width.
            num_levels: number of pyramid levels evaluated per call.
            radius: number of offsets on each side of the centre position.
        """
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []
        self.fmap1 = fmap1
        self.fmap2_pyramid = [fmap2]
        # Only levels 0..num_levels-1 are ever read by __call__.  Building
        # one more (as the code used to) wastes memory and makes avg_pool2d
        # raise "Output size is too small" once the width has reached 1.
        for _ in range(num_levels - 1):
            fmap2 = F.avg_pool2d(fmap2, [1, 2], stride=[1, 2])
            self.fmap2_pyramid.append(fmap2)

    @staticmethod
    def _print_memory_usage():
        """Print the rank and allocated CUDA memory without ever raising."""
        dist = torch.distributed
        # get_rank() raises when no default process group exists, so fall
        # back to rank 0 for single-process runs.
        rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
        allocated = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
        print(f"当前显存消耗量: {rank} {allocated / 1024 / 1024:.2f} MB")

    def corr(self, fmap1, fmap2, coords):
        """Correlate ``fmap1`` with ``fmap2`` sampled at ``coords``.

        Args:
            fmap1: (B, C, H, W) feature map.
            fmap2: feature map to sample (one pyramid level).
            coords: pixel coordinates viewable as (B, H, W * (2r+1), 2),
                last dim ordered (x, y).

        Returns:
            (B, H, W, 2r+1) tensor divided by ``sqrt(C)``.
        """
        B, C, H, W = fmap1.shape
        # map grid coordinates to [-1,1]
        # NOTE(review): the normalisation uses fmap1's W even for pooled
        # fmap2 levels, and follows the align_corners=True convention while
        # grid_sample below runs with align_corners=False.  Kept as-is so
        # the numerical output is unchanged -- confirm whether intended.
        xgrid, ygrid = coords.split([1, 1], dim=-1)
        xgrid = 2*xgrid/(W-1) - 1
        ygrid = 2*ygrid/(H-1) - 1
        grid = torch.cat([xgrid, ygrid], dim=-1)
        disp_num = 2 * self.radius + 1
        # align_corners=False is what the implicit default resolves to;
        # spelling it out silences the "default has changed" warning.
        fmapw_mini = F.grid_sample(
            fmap2, grid.view(B, H, W*disp_num, 2), mode='bilinear',
            padding_mode='zeros', align_corners=False,
        ).view(B, C, H, W, disp_num)  # (B, C, H, W, S)
        corr = torch.sum(fmap1.unsqueeze(-1) * fmapw_mini, dim=1)
        return corr / torch.sqrt(torch.tensor(C).float())

    def __call__(self, coords):
        """Return the lookup for ``coords``.

        ``coords`` is a 4-D tensor (B, D, H, W).  When D == 1 a second
        channel holding the row index is appended.  Channel 0 is divided by
        ``2**level`` per level and receives the offsets.  The result has
        shape (B, num_levels * (2 * radius + 1), H, W) and dtype float32.
        """
        self._print_memory_usage()
        # in case of only disparity used in coordinates
        B, D, H, W = coords.shape
        if D == 1:
            y_coord = torch.arange(H).unsqueeze(1).float().repeat(B, 1, 1, W).to(coords.device)
            coords = torch.cat([coords, y_coord], dim=1)
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape
        fmap1 = self.fmap1

        # (2r+1, 2) offsets: -r..r on channel 0, zero on channel 1.  Level
        # independent, so built once and without torch.meshgrid (which warns
        # when called without ``indexing``).
        offsets = torch.linspace(-r, r, 2*r+1)
        delta = torch.stack([offsets, torch.zeros_like(offsets)], dim=-1).to(coords.device)

        out_pyramid = []
        for i in range(self.num_levels):
            fmap2 = self.fmap2_pyramid[i]
            centroid_lvl = coords.reshape(batch, h1, w1, 1, 2).clone()
            centroid_lvl[..., 0] = centroid_lvl[..., 0] / 2**i
            coords_lvl = centroid_lvl + delta
            out_pyramid.append(self.corr(fmap1, fmap2, coords_lvl))
        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()
class CorrBlock1D:
    """Row-wise all-pairs correlation pyramid with bilinear lookup.

    The constructor computes the correlation between every pair of columns
    on the same row and stores it at ``num_levels`` widths (each pooled by
    2 along the ``fmap2`` width).  Calls sample every level with the
    project's ``bilinear_sampler``.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        """Build the correlation pyramid.

        Args:
            fmap1: 4-D feature map, unpacked as (B, D, H, W1).
            fmap2: 4-D feature map, unpacked as (B, D, H, W2).
            num_levels: number of pyramid levels sampled per call.
            radius: number of offsets on each side of the centre position.
        """
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock1D.corr(fmap1, fmap2)
        batch, h1, w1, _, w2 = corr.shape
        corr = corr.reshape(batch*h1*w1, 1, 1, w2)
        self.corr_pyramid.append(corr)
        # __call__ only reads levels 0..num_levels-1.  Pooling one extra
        # time (as the code used to) is wasted work and makes avg_pool2d
        # raise "Output size is too small" once the width has reached 1.
        for _ in range(self.num_levels - 1):
            corr = F.avg_pool2d(corr, [1, 2], stride=[1, 2])
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        """Sample every pyramid level around channel 0 of ``coords``.

        Returns a float32 tensor of shape (B, -1, H, W) holding the
        per-level samples concatenated along dim 1.  The number of values
        per level depends on ``bilinear_sampler``, which is defined outside
        this file (presumably ``2 * radius + 1`` -- verify).
        """
        r = self.radius
        coords = coords[:, :1].permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        # Level-independent pieces are built once, outside the loop.
        dx = torch.linspace(-r, r, 2*r+1).view(2*r+1, 1).to(coords.device)
        centers = coords.reshape(batch*h1*w1, 1, 1, 1)

        out_pyramid = []
        for i in range(self.num_levels):
            x0 = dx + centers / 2**i
            # Every stored volume has height 1, so the y coordinate is 0.
            y0 = torch.zeros_like(x0)
            coords_lvl = torch.cat([x0, y0], dim=-1)
            sampled = bilinear_sampler(self.corr_pyramid[i], coords_lvl)
            out_pyramid.append(sampled.view(batch, h1, w1, -1))
        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """Return the (B, H, W1, 1, W2) row-wise dot-product correlation.

        The feature dimension D is contracted and the result is divided by
        ``sqrt(D)``.
        """
        B, D, H, W1 = fmap1.shape
        _, _, _, W2 = fmap2.shape
        corr = torch.einsum('aijk,aijh->ajkh', fmap1, fmap2)
        corr = corr.reshape(B, H, W1, 1, W2).contiguous()
        return corr / torch.sqrt(torch.tensor(D).float())
class AbsCorrBlock1D:
    """Row-wise all-pairs L1-distance pyramid with bilinear lookup.

    Same structure as a dot-product correlation pyramid, but each entry is
    the sum of absolute feature differences between a column of ``fmap1``
    and a column of ``fmap2`` on the same row.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        """Build the L1 cost pyramid.

        Args:
            fmap1: 4-D feature map, unpacked as (B, D, H, W1).
            fmap2: 4-D feature map, unpacked as (B, D, H, W2).
            num_levels: number of pyramid levels sampled per call.
            radius: number of offsets on each side of the centre position.
        """
        self.num_levels = num_levels
        self.radius = radius
        self.abs_corr_matrix_pyramid = []

        # all pairs correlation
        abs_corr_matrix = AbsCorrBlock1D.abs_corr(fmap1, fmap2)
        batch, h1, w1, _, w2 = abs_corr_matrix.shape
        abs_corr_matrix = abs_corr_matrix.reshape(batch*h1*w1, 1, 1, w2)
        self.abs_corr_matrix_pyramid.append(abs_corr_matrix)
        # __call__ only reads levels 0..num_levels-1.  Pooling one extra
        # time (as the code used to) is wasted work and makes avg_pool2d
        # raise "Output size is too small" once the width has reached 1.
        for _ in range(self.num_levels - 1):
            abs_corr_matrix = F.avg_pool2d(abs_corr_matrix, [1, 2], stride=[1, 2])
            self.abs_corr_matrix_pyramid.append(abs_corr_matrix)

    def __call__(self, coords):
        """Sample every pyramid level around channel 0 of ``coords``.

        Returns a float32 tensor of shape (B, -1, H, W) holding the
        per-level samples concatenated along dim 1.  The number of values
        per level depends on ``bilinear_sampler``, which is defined outside
        this file (presumably ``2 * radius + 1`` -- verify).
        """
        r = self.radius
        coords = coords[:, :1].permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        # Level-independent pieces are built once, outside the loop.
        dx = torch.linspace(-r, r, 2*r+1).view(2*r+1, 1).to(coords.device)
        centers = coords.reshape(batch*h1*w1, 1, 1, 1)

        out_pyramid = []
        for i in range(self.num_levels):
            x0 = dx + centers / 2**i
            # Every stored volume has height 1, so the y coordinate is 0.
            y0 = torch.zeros_like(x0)
            coords_lvl = torch.cat([x0, y0], dim=-1)
            sampled = bilinear_sampler(self.abs_corr_matrix_pyramid[i], coords_lvl)
            out_pyramid.append(sampled.view(batch, h1, w1, -1))
        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def _current_rank():
        """Return the distributed rank, or 0 when no process group exists."""
        dist = torch.distributed
        # get_rank() raises without an initialised default process group.
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()
        return 0

    @staticmethod
    def abs_corr(fmap1, fmap2):
        """Build the L1 cost matrix between all column pairs of each row.

        This is a per-row matching-cost matrix, not a traditional cost
        volume.

        Args:
            fmap1: feature maps from the left view, (B, C, H, W1).
            fmap2: feature maps from the right view, (B, C, H, W2).

        Returns:
            Tensor of shape (B, H, W1, 1, W2): the sum over channels of
            ``|fmap1 - fmap2|`` divided by ``sqrt(C)``.
        """
        B, D, H, W1 = fmap1.shape
        _, _, _, W2 = fmap2.shape

        # L1 matching cost.  The broadcasted difference is a temporary of
        # shape (B, C, H, W1, W2); abs_() works in place so no second
        # tensor of that size is allocated.
        corr_matrix = (fmap1.unsqueeze(-1) - fmap2.unsqueeze(-2)).abs_().sum(dim=1)  # shape (B, H, W1, W2)

        print("-"*10, " AbsCorrBlock1D: {} ".format(corr_matrix.shape), "-"*10)
        rank = AbsCorrBlock1D._current_rank()
        allocated = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
        print(f"当前显存消耗量: {rank} {allocated / 1024 / 1024:.2f} MB")

        corr_matrix = corr_matrix.reshape(B, H, W1, 1, W2).contiguous()
        return corr_matrix / torch.sqrt(torch.tensor(D).float())
class AlternateCorrBlock:
    """Correlation lookup backed by the ``alt_cuda_corr`` extension.

    The constructor currently raises ``NotImplementedError`` before doing
    any work, so instances cannot be created through it.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        """Always raises ``NotImplementedError``."""
        raise NotImplementedError

        # Unreachable: kept as the intended setup.  Both feature maps are
        # average-pooled by 2 per level and stored as (fmap1, fmap2) pairs.
        self.num_levels = num_levels
        self.radius = radius
        levels = [(fmap1, fmap2)]
        for _ in range(num_levels):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            levels.append((fmap1, fmap2))
        self.pyramid = levels

    def __call__(self, coords):
        """Correlate level-0 ``fmap1`` with every ``fmap2`` level.

        ``coords`` is permuted to channels-last and divided by ``2**level``
        before being handed to ``alt_cuda_corr.forward``.  The stacked
        result is reshaped to (B, -1, H, W) and divided by the square root
        of the feature dimension.
        """
        coords = coords.permute(0, 2, 3, 1)
        B, H, W, _ = coords.shape
        dim = self.pyramid[0][0].shape[1]
        radius = self.radius

        per_level = []
        for level in range(self.num_levels):
            # The extension is given contiguous channels-last tensors.
            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
            fmap2_i = self.pyramid[level][1].permute(0, 2, 3, 1).contiguous()
            coords_i = (coords / 2**level).reshape(B, 1, H, W, 2).contiguous()
            (level_corr,) = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, radius)
            per_level.append(level_corr.squeeze(1))

        stacked = torch.stack(per_level, dim=1).reshape(B, -1, H, W)
        return stacked / torch.sqrt(torch.tensor(dim).float())