|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
def bilinear_sampler(img, coords, mode="bilinear", mask=False, stereo=True):
    """Wrapper for grid_sample, uses pixel coordinates.

    Args:
        img: 4-D input ``(N, C, H, W)`` handed to ``F.grid_sample``.
        coords: sampling locations in pixel units; the last dimension holds
            ``(x, y)``.
        mode: interpolation mode forwarded to ``F.grid_sample``
            (e.g. ``"bilinear"``, ``"nearest"``, ``"bicubic"``).
        mask: if True, additionally return a float mask that is 1 where the
            normalized sampling location lies strictly inside (-1, 1) on both
            axes and 0 elsewhere.
        stereo: if True, ``img`` must be exactly one row high and all y
            coordinates must share a single value. y is then passed through
            unnormalized; with ``H == 1`` and ``align_corners=True`` any finite
            y samples row 0.

    Returns:
        The sampled tensor, or ``(sampled, mask)`` when ``mask`` is True.
    """
    H, W = img.shape[-2:]

    xgrid, ygrid = coords.split([1, 1], dim=-1)
    # Map pixel x in [0, W - 1] onto grid_sample's [-1, 1] (align_corners=True).
    xgrid = 2 * xgrid / (W - 1) - 1

    if not stereo:
        ygrid = 2 * ygrid / (H - 1) - 1
    else:
        # Stereo lookups are purely horizontal over a one-row image.
        assert torch.unique(ygrid).numel() == 1 and H == 1

    img = img.contiguous()
    grid = torch.cat([xgrid, ygrid], dim=-1).contiguous()
    # Forward ``mode`` so callers can actually request a non-bilinear sampler.
    img = F.grid_sample(img, grid, mode=mode, align_corners=True)

    if mask:
        valid = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, valid.float()

    return img
|
|
|
|
|
|
|
|
def coords_grid(batch, ht, wd, device):
    """Build a batch of pixel-coordinate grids.

    Returns a float32 tensor of shape ``(batch, 2, ht, wd)``: channel 0 holds
    each pixel's column (x) index and channel 1 its row (y) index.
    """
    cols = torch.arange(wd, device=device, dtype=torch.float32)
    rows = torch.arange(ht, device=device, dtype=torch.float32)
    # "xy" indexing yields (ht, wd) grids with x varying along the last axis.
    grid_x, grid_y = torch.meshgrid(cols, rows, indexing="xy")
    single = torch.stack((grid_x, grid_y))
    return single.unsqueeze(0).repeat(batch, 1, 1, 1)
|
|
|
|
|
|
|
|
class CorrBlock1D:
    """Multi-level 1-D (horizontal) correlation volume.

    Dot-product correlations between ``fmap1`` and ``fmap2`` are computed along
    the width axis, row by row. They are then average-pooled along the
    ``fmap2`` width axis into a pyramid of ``num_levels`` levels. Calling the
    block samples ``2 * radius + 1`` correlation values per level around each
    pixel's displaced x-coordinate.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        """
        Args:
            fmap1: feature map of shape (B, D, H, W1).
            fmap2: feature map of shape (B, D, H, W2).
            num_levels: number of pyramid levels read by ``__call__``.
            radius: half-width of the lookup window at every level.
        """
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []
        # (B, 2, H, W1) grid of fmap1 pixel coordinates; channel 0 is x.
        self.coords = coords_grid(
            fmap1.shape[0], fmap1.shape[2], fmap1.shape[3], fmap1.device
        )

        corr = CorrBlock1D.corr(fmap1, fmap2)
        batch, h1, w1, dim, w2 = corr.shape
        # One row of W2 correlations per fmap1 pixel, laid out as a one-pixel-
        # high image so it can be pooled/sampled with the 2-D functional ops.
        corr = corr.reshape(batch * h1 * w1, dim, 1, w2)

        self.corr_pyramid.append(corr)
        # __call__ reads exactly num_levels levels, so num_levels - 1 halvings
        # suffice. An extra pool wastes work and memory, and avg_pool2d fails
        # if the last used level is already a single pixel wide.
        for _ in range(self.num_levels - 1):
            corr = F.avg_pool2d(corr, [1, 2], stride=[1, 2])
            self.corr_pyramid.append(corr)

    def __call__(self, flow):
        """Sample correlation windows around ``self.coords + flow``.

        Args:
            flow: displacement added to ``self.coords``; only channel 0 (x) of
                the sum is used. Presumably (B, 2, H, W1) or broadcastable to
                it -- TODO confirm against callers.

        Returns:
            float32 tensor of shape (B, num_levels * (2 * radius + 1), H, W1).
        """
        r = self.radius
        coords = self.coords + flow
        # Keep only the x channel, channels-last: (B, H, W1, 1).
        coords = coords[:, :1].permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        # Window offsets -r..r are the same at every level: build them once,
        # directly on the target device.
        dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
        dx = dx.view(1, 1, 2 * r + 1, 1)
        centers = coords.reshape(batch * h1 * w1, 1, 1, 1)

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            # Level i has been halved i times along x; rescale the centres.
            x0 = dx + centers / 2 ** i
            y0 = torch.zeros_like(x0)
            coords_lvl = torch.cat([x0, y0], dim=-1)
            corr = bilinear_sampler(corr, coords_lvl)
            out_pyramid.append(corr.view(batch, h1, w1, -1))

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """All-pairs correlation along the width axis, computed per row.

        Args:
            fmap1: feature map of shape (B, D, H, W1).
            fmap2: feature map of shape (B, D, H, W2).

        Returns:
            Tensor of shape (B, H, W1, 1, W2) holding dot products over D,
            scaled by 1 / sqrt(D).

        Raises:
            ValueError: if fmap2 is not 4-D or its B, D, H differ from fmap1's.
        """
        B, D, H, W1 = fmap1.shape
        if fmap2.dim() != 4 or fmap2.shape[:3] != fmap1.shape[:3]:
            raise ValueError(
                f"fmap2 must have shape ({B}, {D}, {H}, W2), "
                f"got {tuple(fmap2.shape)}"
            )
        W2 = fmap2.shape[3]
        # corr[b, h, w1, w2] = sum_d fmap1[b, d, h, w1] * fmap2[b, d, h, w2]
        corr = torch.einsum("aijk,aijh->ajkh", fmap1, fmap2)
        corr = corr.reshape(B, H, W1, 1, W2).contiguous()
        return corr / (D ** 0.5)
|
|
|