|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from core.utils.utils import bilinear_sampler |
|
|
|
|
|
try: |
|
|
import corr_sampler |
|
|
except: |
|
|
pass |
|
|
|
|
|
try: |
|
|
import alt_cuda_corr |
|
|
except: |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
class CorrSampler(torch.autograd.Function): |
|
|
@staticmethod |
|
|
def forward(ctx, volume, coords, radius): |
|
|
ctx.save_for_backward(volume,coords) |
|
|
ctx.radius = radius |
|
|
corr, = corr_sampler.forward(volume, coords, radius) |
|
|
return corr |
|
|
@staticmethod |
|
|
def backward(ctx, grad_output): |
|
|
volume, coords = ctx.saved_tensors |
|
|
grad_output = grad_output.contiguous() |
|
|
grad_volume, = corr_sampler.backward(volume, coords, grad_output, ctx.radius) |
|
|
return grad_volume, None, None |
|
|
|
|
|
class CorrBlockFast1D: |
|
|
def __init__(self, fmap1, fmap2, num_levels=4, radius=4): |
|
|
self.num_levels = num_levels |
|
|
self.radius = radius |
|
|
self.corr_pyramid = [] |
|
|
|
|
|
corr = CorrBlockFast1D.corr(fmap1, fmap2) |
|
|
batch, h1, w1, dim, w2 = corr.shape |
|
|
corr = corr.reshape(batch*h1*w1, dim, 1, w2) |
|
|
for i in range(self.num_levels): |
|
|
self.corr_pyramid.append(corr.view(batch, h1, w1, -1, w2//2**i)) |
|
|
corr = F.avg_pool2d(corr, [1,2], stride=[1,2]) |
|
|
|
|
|
def __call__(self, coords): |
|
|
out_pyramid = [] |
|
|
bz, _, ht, wd = coords.shape |
|
|
coords = coords[:, [0]] |
|
|
for i in range(self.num_levels): |
|
|
corr = CorrSampler.apply(self.corr_pyramid[i].squeeze(3), coords/2**i, self.radius) |
|
|
out_pyramid.append(corr.view(bz, -1, ht, wd)) |
|
|
return torch.cat(out_pyramid, dim=1) |
|
|
|
|
|
@staticmethod |
|
|
def corr(fmap1, fmap2): |
|
|
B, D, H, W1 = fmap1.shape |
|
|
_, _, _, W2 = fmap2.shape |
|
|
fmap1 = fmap1.view(B, D, H, W1) |
|
|
fmap2 = fmap2.view(B, D, H, W2) |
|
|
corr = torch.einsum('aijk,aijh->ajkh', fmap1, fmap2) |
|
|
corr = corr.reshape(B, H, W1, 1, W2).contiguous() |
|
|
return corr / torch.sqrt(torch.tensor(D).float()) |
|
|
|
|
|
|
|
|
class PytorchAlternateCorrBlock1D: |
|
|
def __init__(self, fmap1, fmap2, num_levels=4, radius=4): |
|
|
self.num_levels = num_levels |
|
|
self.radius = radius |
|
|
self.corr_pyramid = [] |
|
|
self.fmap1 = fmap1 |
|
|
self.fmap2 = fmap2 |
|
|
|
|
|
def corr(self, fmap1, fmap2, coords): |
|
|
B, D, H, W = fmap2.shape |
|
|
|
|
|
xgrid, ygrid = coords.split([1,1], dim=-1) |
|
|
xgrid = 2*xgrid/(W-1) - 1 |
|
|
ygrid = 2*ygrid/(H-1) - 1 |
|
|
|
|
|
grid = torch.cat([xgrid, ygrid], dim=-1) |
|
|
output_corr = [] |
|
|
for grid_slice in grid.unbind(3): |
|
|
fmapw_mini = F.grid_sample(fmap2, grid_slice, align_corners=True) |
|
|
corr = torch.sum(fmapw_mini * fmap1, dim=1) |
|
|
output_corr.append(corr) |
|
|
corr = torch.stack(output_corr, dim=1).permute(0,2,3,1) |
|
|
|
|
|
return corr / torch.sqrt(torch.tensor(D).float()) |
|
|
|
|
|
def __call__(self, coords): |
|
|
r = self.radius |
|
|
coords = coords.permute(0, 2, 3, 1) |
|
|
batch, h1, w1, _ = coords.shape |
|
|
fmap1 = self.fmap1 |
|
|
fmap2 = self.fmap2 |
|
|
out_pyramid = [] |
|
|
for i in range(self.num_levels): |
|
|
dx = torch.zeros(1) |
|
|
dy = torch.linspace(-r, r, 2*r+1) |
|
|
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) |
|
|
centroid_lvl = coords.reshape(batch, h1, w1, 1, 2).clone() |
|
|
centroid_lvl[...,0] = centroid_lvl[...,0] / 2**i |
|
|
coords_lvl = centroid_lvl + delta.view(-1, 2) |
|
|
corr = self.corr(fmap1, fmap2, coords_lvl) |
|
|
fmap2 = F.avg_pool2d(fmap2, [1, 2], stride=[1, 2]) |
|
|
out_pyramid.append(corr) |
|
|
out = torch.cat(out_pyramid, dim=-1) |
|
|
return out.permute(0, 3, 1, 2).contiguous().float() |
|
|
|
|
|
|
|
|
class PytorchAlternateAbsCorrBlock1D: |
|
|
def __init__(self, fmap1, fmap2, num_levels=4, radius=4): |
|
|
self.num_levels = num_levels |
|
|
self.radius = radius |
|
|
self.corr_pyramid = [] |
|
|
self.fmap1 = fmap1 |
|
|
|
|
|
self.fmap2_pyramid = [fmap2] |
|
|
for i in range(num_levels): |
|
|
fmap2 = F.avg_pool2d(fmap2, [1, 2], stride=[1, 2]) |
|
|
self.fmap2_pyramid.append(fmap2) |
|
|
|
|
|
def corr(self, fmap1, fmap2, coords): |
|
|
B, C, H, W = fmap1.shape |
|
|
|
|
|
xgrid, ygrid = coords.split([1,1], dim=-1) |
|
|
xgrid = 2*xgrid/(W-1) - 1 |
|
|
ygrid = 2*ygrid/(H-1) - 1 |
|
|
|
|
|
grid = torch.cat([xgrid, ygrid], dim=-1) |
|
|
|
|
|
disp_num = 2 * self.radius + 1 |
|
|
fmapw_mini = F.grid_sample(fmap2, grid.view(B, H, W*disp_num, 2), mode='bilinear', |
|
|
padding_mode='zeros').view(B, C, H, W, disp_num) |
|
|
corr = torch.sum(fmap1.unsqueeze(-1) * fmapw_mini, dim=1) |
|
|
|
|
|
return corr / torch.sqrt(torch.tensor(C).float()) |
|
|
|
|
|
def __call__(self, coords): |
|
|
print(f"当前显存消耗量: {torch.distributed.get_rank()} {torch.cuda.memory_allocated() / 1024 / 1024:.2f} MB") |
|
|
|
|
|
|
|
|
B, D, H, W = coords.shape |
|
|
if D==1: |
|
|
y_coord = torch.arange(H).unsqueeze(1).float().repeat(B, 1, 1, W).to(coords.device) |
|
|
coords = torch.cat([coords,y_coord], dim=1) |
|
|
|
|
|
r = self.radius |
|
|
coords = coords.permute(0, 2, 3, 1) |
|
|
batch, h1, w1, _ = coords.shape |
|
|
|
|
|
fmap1 = self.fmap1 |
|
|
out_pyramid = [] |
|
|
for i in range(self.num_levels): |
|
|
fmap2 = self.fmap2_pyramid[i] |
|
|
|
|
|
dx = torch.zeros(1) |
|
|
dy = torch.linspace(-r, r, 2*r+1) |
|
|
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) |
|
|
centroid_lvl = coords.reshape(batch, h1, w1, 1, 2).clone() |
|
|
centroid_lvl[...,0] = centroid_lvl[...,0] / 2**i |
|
|
coords_lvl = centroid_lvl + delta.view(-1, 2) |
|
|
|
|
|
corr = self.corr(fmap1, fmap2, coords_lvl) |
|
|
out_pyramid.append(corr) |
|
|
out = torch.cat(out_pyramid, dim=-1) |
|
|
return out.permute(0, 3, 1, 2).contiguous().float() |
|
|
|
|
|
|
|
|
class CorrBlock1D: |
|
|
def __init__(self, fmap1, fmap2, num_levels=4, radius=4): |
|
|
self.num_levels = num_levels |
|
|
self.radius = radius |
|
|
self.corr_pyramid = [] |
|
|
|
|
|
|
|
|
corr = CorrBlock1D.corr(fmap1, fmap2) |
|
|
|
|
|
batch, h1, w1, _, w2 = corr.shape |
|
|
corr = corr.reshape(batch*h1*w1, 1, 1, w2) |
|
|
|
|
|
self.corr_pyramid.append(corr) |
|
|
for i in range(self.num_levels): |
|
|
corr = F.avg_pool2d(corr, [1,2], stride=[1,2]) |
|
|
self.corr_pyramid.append(corr) |
|
|
|
|
|
def __call__(self, coords): |
|
|
r = self.radius |
|
|
coords = coords[:, :1].permute(0, 2, 3, 1) |
|
|
batch, h1, w1, _ = coords.shape |
|
|
|
|
|
|
|
|
|
|
|
out_pyramid = [] |
|
|
for i in range(self.num_levels): |
|
|
corr = self.corr_pyramid[i] |
|
|
dx = torch.linspace(-r, r, 2*r+1) |
|
|
dx = dx.view(2*r+1, 1).to(coords.device) |
|
|
x0 = dx + coords.reshape(batch*h1*w1, 1, 1, 1) / 2**i |
|
|
y0 = torch.zeros_like(x0) |
|
|
|
|
|
coords_lvl = torch.cat([x0,y0], dim=-1) |
|
|
corr = bilinear_sampler(corr, coords_lvl) |
|
|
corr = corr.view(batch, h1, w1, -1) |
|
|
out_pyramid.append(corr) |
|
|
|
|
|
out = torch.cat(out_pyramid, dim=-1) |
|
|
return out.permute(0, 3, 1, 2).contiguous().float() |
|
|
|
|
|
@staticmethod |
|
|
def corr(fmap1, fmap2): |
|
|
B, D, H, W1 = fmap1.shape |
|
|
_, _, _, W2 = fmap2.shape |
|
|
fmap1 = fmap1.view(B, D, H, W1) |
|
|
fmap2 = fmap2.view(B, D, H, W2) |
|
|
corr = torch.einsum('aijk,aijh->ajkh', fmap1, fmap2) |
|
|
corr = corr.reshape(B, H, W1, 1, W2).contiguous() |
|
|
return corr / torch.sqrt(torch.tensor(D).float()) |
|
|
|
|
|
class AbsCorrBlock1D: |
|
|
def __init__(self, fmap1, fmap2, num_levels=4, radius=4): |
|
|
self.num_levels = num_levels |
|
|
self.radius = radius |
|
|
self.abs_corr_matrix_pyramid = [] |
|
|
|
|
|
|
|
|
abs_corr_matrix = AbsCorrBlock1D.abs_corr(fmap1, fmap2) |
|
|
|
|
|
batch, h1, w1, _, w2 = abs_corr_matrix.shape |
|
|
abs_corr_matrix = abs_corr_matrix.reshape(batch*h1*w1, 1, 1, w2) |
|
|
|
|
|
self.abs_corr_matrix_pyramid.append(abs_corr_matrix) |
|
|
for i in range(self.num_levels): |
|
|
abs_corr_matrix = F.avg_pool2d(abs_corr_matrix, [1,2], stride=[1,2]) |
|
|
self.abs_corr_matrix_pyramid.append(abs_corr_matrix) |
|
|
|
|
|
def __call__(self, coords): |
|
|
r = self.radius |
|
|
coords = coords[:, :1].permute(0, 2, 3, 1) |
|
|
batch, h1, w1, _ = coords.shape |
|
|
|
|
|
out_pyramid = [] |
|
|
for i in range(self.num_levels): |
|
|
abs_corr_matrix = self.abs_corr_matrix_pyramid[i] |
|
|
dx = torch.linspace(-r, r, 2*r+1) |
|
|
dx = dx.view(2*r+1, 1).to(coords.device) |
|
|
x0 = dx + coords.reshape(batch*h1*w1, 1, 1, 1) / 2**i |
|
|
y0 = torch.zeros_like(x0) |
|
|
|
|
|
coords_lvl = torch.cat([x0,y0], dim=-1) |
|
|
abs_corr_matrix = bilinear_sampler(abs_corr_matrix, coords_lvl) |
|
|
abs_corr_matrix = abs_corr_matrix.view(batch, h1, w1, -1) |
|
|
out_pyramid.append(abs_corr_matrix) |
|
|
|
|
|
out = torch.cat(out_pyramid, dim=-1) |
|
|
return out.permute(0, 3, 1, 2).contiguous().float() |
|
|
|
|
|
@staticmethod |
|
|
def abs_corr(fmap1, fmap2): |
|
|
"""fucntion: build the correlation matrix (not traditional cost volume) for each pixel in the same line. |
|
|
args: |
|
|
fmap1: feature maps from left view, B*C*H*W1; |
|
|
fmap2: feature maps from right view, B*C*H*W2; |
|
|
return: |
|
|
the correlation matrix, B*H*W1*W2; |
|
|
""" |
|
|
B, D, H, W1 = fmap1.shape |
|
|
_, _, _, W2 = fmap2.shape |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
corr_matrix = (fmap1.unsqueeze(-1) - fmap2.unsqueeze(-2)).abs_().sum(dim=1) |
|
|
|
|
|
print("-"*10, " AbsCorrBlock1D: {} ".format(corr_matrix.shape), "-"*10) |
|
|
print(f"当前显存消耗量: {torch.distributed.get_rank()} {torch.cuda.memory_allocated() / 1024 / 1024:.2f} MB") |
|
|
|
|
|
corr_matrix = corr_matrix.reshape(B, H, W1, 1, W2).contiguous() |
|
|
return corr_matrix / torch.sqrt(torch.tensor(D).float()) |
|
|
|
|
|
class AlternateCorrBlock: |
|
|
def __init__(self, fmap1, fmap2, num_levels=4, radius=4): |
|
|
raise NotImplementedError |
|
|
self.num_levels = num_levels |
|
|
self.radius = radius |
|
|
|
|
|
self.pyramid = [(fmap1, fmap2)] |
|
|
for i in range(self.num_levels): |
|
|
fmap1 = F.avg_pool2d(fmap1, 2, stride=2) |
|
|
fmap2 = F.avg_pool2d(fmap2, 2, stride=2) |
|
|
self.pyramid.append((fmap1, fmap2)) |
|
|
|
|
|
def __call__(self, coords): |
|
|
coords = coords.permute(0, 2, 3, 1) |
|
|
B, H, W, _ = coords.shape |
|
|
dim = self.pyramid[0][0].shape[1] |
|
|
|
|
|
corr_list = [] |
|
|
for i in range(self.num_levels): |
|
|
r = self.radius |
|
|
fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() |
|
|
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() |
|
|
|
|
|
coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() |
|
|
corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) |
|
|
corr_list.append(corr.squeeze(1)) |
|
|
|
|
|
corr = torch.stack(corr_list, dim=1) |
|
|
corr = corr.reshape(B, -1, H, W) |
|
|
return corr / torch.sqrt(torch.tensor(dim).float()) |
|
|
|