# Extracted from a Hugging Face Space file view (Space status: "Running on Zero").
# File size: 5,471 bytes; revision 8ca4dce.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils import bilinear_sampler, coords_grid, manual_pad
class AGCL:
    """
    Implementation of Adaptive Group Correlation Layer (AGCL).

    The layer keeps a pair of feature maps and, when called, returns a
    correlation volume between them. Channels are split into groups; each
    group is correlated over a small search window around the position given
    by the current flow, and the per-group results are concatenated along the
    channel axis.

    Args:
        fmap1: Left feature map; indexed as ``[N, C, H, W]``.
        fmap2: Right feature map, same layout as ``fmap1``.
        att: Optional callable ``att(left, right) -> (left, right)`` applied
            to the features flattened to ``[N, H*W, C]`` before correlation
            in the non-iterative path. ``None`` disables it.
    """

    # Number of channel groups the feature maps are split into.
    _NUM_GROUPS = 4

    def __init__(self, fmap1, fmap2, att=None):
        self.fmap1 = fmap1
        self.fmap2 = fmap2
        self.att = att
        # Base sampling grid built once from (batch, height, width, device).
        # NOTE(review): presumably a [N, 2, H, W] pixel-coordinate grid; confirm
        # against ``coords_grid`` in ``.utils``.
        self.coords = coords_grid(
            fmap1.shape[0], fmap1.shape[2], fmap1.shape[3], fmap1.device
        )

    def __call__(self, flow, extra_offset, small_patch=False, iter_mode=False):
        """Compute the correlation volume for the current ``flow``.

        Args:
            flow: Displacement added to the base coordinate grid.
            extra_offset: Per-pixel offsets added to the search window; only
                used when ``iter_mode`` is False.
            small_patch: Use a 3x3 search window instead of 1x9.
            iter_mode: If True use :meth:`corr_iter` (no attention, no extra
                offsets); otherwise use :meth:`corr_att_offset`.

        Returns:
            The correlation tensor produced by the selected method.
        """
        if iter_mode:
            return self.corr_iter(self.fmap1, self.fmap2, flow, small_patch)
        return self.corr_att_offset(
            self.fmap1, self.fmap2, flow, extra_offset, small_patch
        )

    @classmethod
    def _patch_config(cls, small_patch):
        """Return per-group ``(psize_list, dilate_list)`` for the search window."""
        psize = (3, 3) if small_patch else (1, 9)
        return [psize] * cls._NUM_GROUPS, [(1, 1)] * cls._NUM_GROUPS

    def get_correlation(self, left_feature, right_feature, psize=(3, 3), dilate=(1, 1)):
        """Correlate ``left_feature`` with shifted crops of ``right_feature``.

        The right feature is padded by half the (dilated) window size and one
        channel-mean of the element-wise product is produced per window
        position.

        Args:
            left_feature: Tensor indexed as ``[N, C, H, W]``.
            right_feature: Tensor with the same shape as ``left_feature``.
            psize: ``(height, width)`` of the search window.
            dilate: ``(y, x)`` step between window positions.

        Returns:
            Tensor with one channel per window position, concatenated along
            dim 1 in row-major (y outer, x inner) order.
        """
        H, W = left_feature.shape[2], left_feature.shape[3]
        di_y, di_x = dilate[0], dilate[1]
        pady, padx = psize[0] // 2 * di_y, psize[1] // 2 * di_x

        right_pad = manual_pad(right_feature, pady, padx)

        corr_list = []
        for h in range(0, pady * 2 + 1, di_y):
            for w in range(0, padx * 2 + 1, di_x):
                right_crop = right_pad[:, :, h : h + H, w : w + W]
                assert right_crop.shape == left_feature.shape
                corr_list.append(
                    torch.mean(left_feature * right_crop, dim=1, keepdim=True)
                )

        return torch.cat(corr_list, dim=1)

    def corr_iter(self, left_feature, right_feature, flow, small_patch):
        """Group correlation after sampling the right features at ``coords + flow``.

        Args:
            left_feature: Left features, ``[N, C, H, W]``.
            right_feature: Right features, same layout.
            flow: Displacement added to ``self.coords``.
            small_patch: Use a 3x3 window per group instead of 1x9.

        Returns:
            Per-group correlations from :meth:`get_correlation`, concatenated
            along dim 1.
        """
        # Channels-last coordinates are what ``bilinear_sampler`` is called with.
        coords = (self.coords + flow).permute(0, 2, 3, 1)
        right_feature = bilinear_sampler(right_feature, coords)

        psize_list, dilate_list = self._patch_config(small_patch)

        lefts = torch.split(
            left_feature, left_feature.shape[1] // self._NUM_GROUPS, dim=1
        )
        rights = torch.split(
            right_feature, right_feature.shape[1] // self._NUM_GROUPS, dim=1
        )

        corrs = [
            self.get_correlation(left, right, psize, dilate)
            for left, right, psize, dilate in zip(lefts, rights, psize_list, dilate_list)
        ]
        return torch.cat(corrs, dim=1)

    def corr_att_offset(
        self, left_feature, right_feature, flow, extra_offset, small_patch
    ):
        """Group correlation with optional attention and learned extra offsets.

        Args:
            left_feature: Left features, ``[N, C, H, W]``.
            right_feature: Right features, same layout.
            flow: Displacement added to ``self.coords``.
            extra_offset: Tensor reshapeable to ``[N, 9, 2, H, W]``; an (x, y)
                offset for each of the 9 search positions at every pixel.
            small_patch: Use a 3x3 window per group instead of 1x9.

        Returns:
            Tensor of per-group correlations concatenated along dim 1; each
            group contributes one channel per search position.
        """
        N, C, H, W = left_feature.shape

        if self.att is not None:
            # 'n c h w -> n (h w) c'
            left_feature = left_feature.permute(0, 2, 3, 1).reshape(N, H * W, C)
            right_feature = right_feature.permute(0, 2, 3, 1).reshape(N, H * W, C)
            left_feature, right_feature = self.att(left_feature, right_feature)
            # 'n (h w) c -> n c h w'
            left_feature, right_feature = [
                x.reshape(N, H, W, C).permute(0, 3, 1, 2)
                for x in [left_feature, right_feature]
            ]

        lefts = torch.split(
            left_feature, left_feature.shape[1] // self._NUM_GROUPS, dim=1
        )
        rights = torch.split(
            right_feature, right_feature.shape[1] // self._NUM_GROUPS, dim=1
        )
        C = C // self._NUM_GROUPS  # channels per group

        psize_list, dilate_list = self._patch_config(small_patch)

        # Both window layouts (3x3 and 1x9) contain 9 search positions.
        search_num = 9
        extra_offset = extra_offset.reshape(N, search_num, 2, H, W).permute(
            0, 1, 3, 4, 2
        )  # [N, search_num, H, W, 2]

        # Loop-invariant: flow-displaced grid, channels-last, with a singleton
        # search axis so it broadcasts against the per-position offsets.
        base_coords = (self.coords + flow).permute(0, 2, 3, 1).unsqueeze(1)

        device = self.fmap1.device
        corrs = []
        for left, right, psize, dilate in zip(lefts, rights, psize_list, dilate_list):
            dilatey, dilatex = dilate[0], dilate[1]
            ry = psize[0] // 2 * dilatey
            rx = psize[1] // 2 * dilatex

            x_grid, y_grid = torch.meshgrid(
                torch.arange(-rx, rx + 1, dilatex, device=device),
                torch.arange(-ry, ry + 1, dilatey, device=device),
                indexing='xy',
            )
            # [search_num, 2] window offsets as (x, y) pairs.
            offsets = torch.stack((x_grid, y_grid)).reshape(2, -1).permute(1, 0)
            # [1, search_num, 1, 1, 2]; broadcasting over batch and pixels
            # replaces the explicit per-batch repeat.
            offsets = offsets[None, :, None, None, :] + extra_offset

            coords = base_coords + offsets  # [N, search_num, H, W, 2]
            coords = coords.reshape(N, -1, W, 2)  # [N, search_num*H, W, 2]

            sampled = bilinear_sampler(right, coords)  # [N, C, search_num*H, W]
            sampled = sampled.reshape(N, C, -1, H, W)  # [N, C, search_num, H, W]

            # ``left`` broadcasts along the search axis; mean over channels.
            corrs.append(torch.mean(left.unsqueeze(2) * sampled, dim=1))

        return torch.cat(corrs, dim=1)
|