Spaces:
Runtime error
Runtime error
Commit
·
efb54b3
1
Parent(s):
c887498
Create raft_core_corr.py
Browse files- raft_core_corr.py +91 -0
raft_core_corr.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from raft_core_utils_utils import bilinear_sampler, coords_grid
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
import alt_cuda_corr
|
| 7 |
+
except:
|
| 8 |
+
# alt_cuda_corr is not compiled
|
| 9 |
+
pass
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class CorrBlock:
|
| 13 |
+
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
|
| 14 |
+
self.num_levels = num_levels
|
| 15 |
+
self.radius = radius
|
| 16 |
+
self.corr_pyramid = []
|
| 17 |
+
|
| 18 |
+
# all pairs correlation
|
| 19 |
+
corr = CorrBlock.corr(fmap1, fmap2)
|
| 20 |
+
|
| 21 |
+
batch, h1, w1, dim, h2, w2 = corr.shape
|
| 22 |
+
corr = corr.reshape(batch*h1*w1, dim, h2, w2)
|
| 23 |
+
|
| 24 |
+
self.corr_pyramid.append(corr)
|
| 25 |
+
for i in range(self.num_levels-1):
|
| 26 |
+
corr = F.avg_pool2d(corr, 2, stride=2)
|
| 27 |
+
self.corr_pyramid.append(corr)
|
| 28 |
+
|
| 29 |
+
def __call__(self, coords):
|
| 30 |
+
r = self.radius
|
| 31 |
+
coords = coords.permute(0, 2, 3, 1)
|
| 32 |
+
batch, h1, w1, _ = coords.shape
|
| 33 |
+
|
| 34 |
+
out_pyramid = []
|
| 35 |
+
for i in range(self.num_levels):
|
| 36 |
+
corr = self.corr_pyramid[i]
|
| 37 |
+
dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
|
| 38 |
+
dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
|
| 39 |
+
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
|
| 40 |
+
|
| 41 |
+
centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
|
| 42 |
+
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
|
| 43 |
+
coords_lvl = centroid_lvl + delta_lvl
|
| 44 |
+
|
| 45 |
+
corr = bilinear_sampler(corr, coords_lvl)
|
| 46 |
+
corr = corr.view(batch, h1, w1, -1)
|
| 47 |
+
out_pyramid.append(corr)
|
| 48 |
+
|
| 49 |
+
out = torch.cat(out_pyramid, dim=-1)
|
| 50 |
+
return out.permute(0, 3, 1, 2).contiguous().float()
|
| 51 |
+
|
| 52 |
+
@staticmethod
|
| 53 |
+
def corr(fmap1, fmap2):
|
| 54 |
+
batch, dim, ht, wd = fmap1.shape
|
| 55 |
+
fmap1 = fmap1.view(batch, dim, ht*wd)
|
| 56 |
+
fmap2 = fmap2.view(batch, dim, ht*wd)
|
| 57 |
+
|
| 58 |
+
corr = torch.matmul(fmap1.transpose(1,2), fmap2)
|
| 59 |
+
corr = corr.view(batch, ht, wd, 1, ht, wd)
|
| 60 |
+
return corr / torch.sqrt(torch.tensor(dim).float())
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class AlternateCorrBlock:
|
| 64 |
+
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
|
| 65 |
+
self.num_levels = num_levels
|
| 66 |
+
self.radius = radius
|
| 67 |
+
|
| 68 |
+
self.pyramid = [(fmap1, fmap2)]
|
| 69 |
+
for i in range(self.num_levels):
|
| 70 |
+
fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
|
| 71 |
+
fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
|
| 72 |
+
self.pyramid.append((fmap1, fmap2))
|
| 73 |
+
|
| 74 |
+
def __call__(self, coords):
|
| 75 |
+
coords = coords.permute(0, 2, 3, 1)
|
| 76 |
+
B, H, W, _ = coords.shape
|
| 77 |
+
dim = self.pyramid[0][0].shape[1]
|
| 78 |
+
|
| 79 |
+
corr_list = []
|
| 80 |
+
for i in range(self.num_levels):
|
| 81 |
+
r = self.radius
|
| 82 |
+
fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
|
| 83 |
+
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
|
| 84 |
+
|
| 85 |
+
coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
|
| 86 |
+
corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
|
| 87 |
+
corr_list.append(corr.squeeze(1))
|
| 88 |
+
|
| 89 |
+
corr = torch.stack(corr_list, dim=1)
|
| 90 |
+
corr = corr.reshape(B, -1, H, W)
|
| 91 |
+
return corr / torch.sqrt(torch.tensor(dim).float())
|