| import torch |
| import torch.nn.functional as F |
| from .utils.utils import bilinear_sampler, coords_grid |
|
|
try:
    import alt_cuda_corr
except ImportError:
    # alt_cuda_corr is an optional CUDA extension; fall back to the
    # pure-PyTorch CorrBlock when it has not been compiled.
    alt_cuda_corr = None
|
|
| class CorrBlock: |
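    """All-pairs correlation volume with an average-pooled pyramid.

    The 4D correlation volume between fmap1 and fmap2 is computed once in
    __init__ and pooled into num_levels levels. Calling the block with flow
    coordinates of shape (batch, 2, h1, w1) samples a (2r+1) x (2r+1) window
    at every level and returns (batch, num_levels*(2r+1)**2, h1, w1) features.
    """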
| def __init__(self, fmap1, fmap2, num_levels=4, radius=4): |
| self.num_levels = num_levels |
| self.radius = radius |
| self.corr_pyramid = [] |
|
|
| |
        # all-pairs correlation between the two feature maps
        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch*h1*w1, dim, h2, w2)

        # build the pyramid by repeatedly pooling over the dimensions of image 2
        self.corr_pyramid.append(corr)
        for i in range(self.num_levels-1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)
|
|
    def __call__(self, coords):
        r = self.radius
        # (batch, 2, h1, w1) -> (batch, h1, w1, 2)
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]

            # (2r+1) x (2r+1) window of offsets around each correspondence
            dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
            dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
            delta = torch.stack(torch.meshgrid(dy, dx), dim=-1)

            # centroids scaled to this pyramid level, plus the window offsets
            centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            # bilinearly sample the correlation volume at the lookup coordinates
            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()
|
|
| @staticmethod |
| def corr(fmap1, fmap2): |
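        # dense dot-product between every pair of feature vectors from the two
        # maps; output shape (batch, h1, w1, 1, h2, w2), scaled by 1/sqrt(dim)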
| batch, dim, ht, wd = fmap1.shape |
| fmap1 = fmap1.view(batch, dim, ht*wd) |
| fmap2 = fmap2.view(batch, dim, ht*wd) |
|
|
| corr = torch.matmul(fmap1.transpose(1,2), fmap2) |
| corr = corr.view(batch, ht, wd, 1, ht, wd) |
| return corr / torch.sqrt(torch.tensor(dim).float()) |
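
# Usage sketch (comments only, not executed on import). The shapes below are
# illustrative assumptions; coords_grid(batch, ht, wd) from .utils.utils is
# assumed to return a (batch, 2, ht, wd) grid of pixel coordinates.
#
#     fmap1 = torch.randn(1, 256, 46, 62).cuda()
#     fmap2 = torch.randn(1, 256, 46, 62).cuda()
#     corr_fn = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)
#     coords = coords_grid(1, 46, 62).cuda()
#     corr = corr_fn(coords)    # (1, 4*(2*4+1)**2, 46, 62) = (1, 324, 46, 62)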
|
|
| class CorrLayer(torch.autograd.Function): |
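    """torch.autograd.Function wrapping the alt_cuda_corr extension: forward
    performs the windowed correlation lookup, backward returns the gradients
    for fmap1, fmap2 and coords provided by the compiled CUDA extension."""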
| @staticmethod |
| def forward(ctx, fmap1, fmap2, coords, r): |
| fmap1 = fmap1.contiguous() |
| fmap2 = fmap2.contiguous() |
| coords = coords.contiguous() |
| ctx.save_for_backward(fmap1, fmap2, coords) |
| ctx.r = r |
        corr, = alt_cuda_corr.forward(fmap1, fmap2, coords, ctx.r)
| return corr |
|
|
| @staticmethod |
| def backward(ctx, grad_corr): |
| fmap1, fmap2, coords = ctx.saved_tensors |
| grad_corr = grad_corr.contiguous() |
| fmap1_grad, fmap2_grad, coords_grad = \ |
            alt_cuda_corr.backward(fmap1, fmap2, coords, grad_corr, ctx.r)
| return fmap1_grad, fmap2_grad, coords_grad, None |
|
|
| class AlternateCorrBlock: |
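    """Memory-efficient alternative to CorrBlock: keeps feature-map pyramids
    instead of a precomputed all-pairs volume and computes local correlations
    on the fly via the alt_cuda_corr CUDA extension at lookup time."""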
| def __init__(self, fmap1, fmap2, num_levels=4, radius=4): |
| self.num_levels = num_levels |
| self.radius = radius |
|
|
| self.pyramid = [(fmap1, fmap2)] |
| for i in range(self.num_levels): |
| fmap1 = F.avg_pool2d(fmap1, 2, stride=2) |
| fmap2 = F.avg_pool2d(fmap2, 2, stride=2) |
| self.pyramid.append((fmap1, fmap2)) |
|
|
| def __call__(self, coords): |
|
|
| coords = coords.permute(0, 2, 3, 1) |
| B, H, W, _ = coords.shape |
|
|
        corr_list = []
        for i in range(self.num_levels):
            r = self.radius
            # image-1 features stay at full resolution; image-2 features come
            # from the pooled pyramid level i
            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()

            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
            corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
            corr_list.append(corr.squeeze(1))
|
|
        corr = torch.stack(corr_list, dim=1)
        corr = corr.reshape(B, -1, H, W)

        # same normalization as CorrBlock.corr: scale by 1/sqrt(feature dim)
        dim = self.pyramid[0][0].shape[1]
        return corr / torch.sqrt(torch.tensor(dim).float())
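
# Usage mirrors CorrBlock (a sketch, assuming the alt_cuda_corr extension is
# built and that inputs live on a CUDA device):
#
#     corr_fn = AlternateCorrBlock(fmap1, fmap2, num_levels=4, radius=4)
#     corr = corr_fn(coords)    # (B, num_levels*(2r+1)**2, H, W), like CorrBlock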
|
|